<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://wilsonwongso.dev/feed.xml" rel="self" type="application/atom+xml" /><link href="https://wilsonwongso.dev/" rel="alternate" type="text/html" /><updated>2026-05-12T14:40:39+10:00</updated><id>https://wilsonwongso.dev/feed.xml</id><title type="html">Wilson’s Homepage</title><subtitle>Computer Science PhD Candidate.</subtitle><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><entry><title type="html">Predicting Phonemes with BERT</title><link href="https://wilsonwongso.dev/posts/2022/04/predicting-phonemes-with-bert/" rel="alternate" type="text/html" title="Predicting Phonemes with BERT" /><published>2022-04-23T00:00:00+10:00</published><updated>2022-04-23T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2022/04/predicting-phonemes-with-bert</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2022/04/predicting-phonemes-with-bert/"><![CDATA[<p>Our team at <a href="https://www.bookbotkids.com/">Bookbot</a> is currently developing a grapheme-to-phoneme Python package for Bahasa Indonesia. The package is heavily inspired by its English counterpart, <a href="https://github.com/Kyubyong/g2p">g2p</a>. Much of our design is borrowed from that library, most notably the steps used to predict phonemes. The English g2p uses the following algorithm (cf. g2p’s <a href="https://github.com/Kyubyong/g2p#algorithm">README</a>):</p>

<ol>
  <li>Spells out Arabic numerals and some currency symbols (e.g. $200 -&gt; two hundred dollars). (This is borrowed from Keith Ito’s code.)</li>
  <li>Attempts to retrieve the correct pronunciation of heteronyms based on their part-of-speech (POS) tags.</li>
  <li>Looks up The CMU Pronouncing Dictionary for non-homographs.</li>
  <li>For OOVs, we predict their pronunciations using our neural net model.</li>
</ol>

<p>Steps 1-3 were relatively easy to develop, since we were able to find an online Bahasa Indonesia lexicon in <a href="https://github.com/open-dict-data/ipa-dict/blob/master/data/ma.txt">ipa-dict</a>. Step 4, however, was challenging. The authors of g2p used a recurrent sequence2sequence <a href="https://arxiv.org/abs/1409.1259">GRU</a> that takes graphemes as input and outputs phonemes. This approach is useful because nobody has to work out the conversion rules by hand: the neural net does the heavy lifting when predicting pronunciations for unseen words.</p>
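<p>The lookup-then-fallback structure of steps 3 and 4 can be sketched as follows. This is a minimal illustration, not the g2p library’s actual code: <code class="language-plaintext highlighter-rouge">words_to_phonemes</code>, the toy lexicon, and the <code class="language-plaintext highlighter-rouge">predict_oov</code> callable are hypothetical stand-ins (a real system would plug in the ipa-dict lexicon and a trained model), and steps 1 and 2 are omitted.</p>

```python
from typing import Callable, Dict, List


def words_to_phonemes(
    text: str,
    lexicon: Dict[str, str],
    predict_oov: Callable[[str], str],
) -> List[str]:
    """Return one pronunciation per whitespace-separated word in `text`.

    Hypothetical helper: known words come from `lexicon` (step 3), and
    out-of-vocabulary words are passed to `predict_oov` (step 4).
    """
    pronunciations = []
    for word in text.lower().split():
        if word in lexicon:
            pronunciations.append(lexicon[word])
        else:
            pronunciations.append(predict_oov(word))
    return pronunciations


# Toy lexicon built from transcriptions quoted in this post
toy_lexicon = {"delapan": "dəlapan", "tayangan": "tajaŋan"}

# Placeholder "model" that just echoes the word back in upper case
print(words_to_phonemes("Delapan tayangan zona", toy_lexicon, str.upper))
# ['dəlapan', 'tajaŋan', 'ZONA']
```

<p>Swapping in a better OOV predictor only changes the callable; the lookup path stays the same.</p>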

<p>Seeing their success, we attempted a similar approach: we trained a recurrent sequence2sequence <a href="https://doi.org/10.1162/neco.1997.9.8.1735">LSTM</a> on the aforementioned lexicon (the trained model is available <a href="https://huggingface.co/bookbot/id-g2p-lstm">here</a>). As expected, the model worked well for relatively simple words and for words whose sub-words may have appeared in the training set. It also achieved a validation accuracy of over 97%, so we thought it would suffice.</p>

<p>We then converted the model to ONNX for deployment and soon had a working prototype g2p library that used exactly the same approach as the English g2p. After experimenting with it further, we quickly found an issue with the seq2seq approach. Although it performed well on the held-out validation set, it quickly broke down on words that look very different from the training data, for instance names of people or places. This is not surprising given how small its training data is, but we thought we could do better.</p>

<p>First, we realized that phonemes <strong>in the <a href="https://en.wikipedia.org/wiki/International_Phonetic_Alphabet">IPA</a> format that our data was in</strong> were not too different from their corresponding graphemes. Here are a few examples:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">sampingnya</code> = <code class="language-plaintext highlighter-rouge">sampiŋɲa</code></li>
  <li><code class="language-plaintext highlighter-rouge">tayangan</code> = <code class="language-plaintext highlighter-rouge">tajaŋan</code></li>
  <li><code class="language-plaintext highlighter-rouge">bepercikan</code> = <code class="language-plaintext highlighter-rouge">bəpərtʃikan</code></li>
  <li><code class="language-plaintext highlighter-rouge">deduktif</code> = <code class="language-plaintext highlighter-rouge">deduʔtif</code></li>
</ul>

<p>You may notice that there are simple mapping rules we could infer by hand. Indeed, we found the following rules to be sufficient:</p>

<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">PHONETIC_MAPPING</span> <span class="o">=</span> <span class="p">{</span>
    <span class="s">"ny"</span><span class="p">:</span> <span class="s">"ɲ"</span><span class="p">,</span>
    <span class="s">"ng"</span><span class="p">:</span> <span class="s">"ŋ"</span><span class="p">,</span>
    <span class="s">"c"</span><span class="p">:</span> <span class="s">"tʃ"</span><span class="p">,</span>
    <span class="s">"'"</span><span class="p">:</span> <span class="s">"ʔ"</span><span class="p">,</span>
    <span class="s">"aa"</span><span class="p">:</span> <span class="s">"aʔa"</span><span class="p">,</span>
    <span class="s">"ii"</span><span class="p">:</span> <span class="s">"iʔi"</span><span class="p">,</span>
    <span class="s">"oo"</span><span class="p">:</span> <span class="s">"oʔo"</span><span class="p">,</span>
    <span class="s">"əə"</span><span class="p">:</span> <span class="s">"əʔə"</span><span class="p">,</span>
    <span class="s">"j"</span><span class="p">:</span> <span class="s">"dʒ"</span><span class="p">,</span>
    <span class="s">"y"</span><span class="p">:</span> <span class="s">"j"</span><span class="p">,</span>
    <span class="s">"q"</span><span class="p">:</span> <span class="s">"k"</span>
<span class="p">}</span>

<span class="n">CONSONANTS</span> <span class="o">=</span> <span class="s">"bdfghjklmnprstvwxɲ"</span>

<span class="k">def</span> <span class="nf">g2p</span><span class="p">(</span><span class="n">text</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">text</span><span class="p">.</span><span class="n">endswith</span><span class="p">(</span><span class="s">"k"</span><span class="p">):</span>
        <span class="n">text</span> <span class="o">=</span> <span class="n">text</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="s">"ʔ"</span>

    <span class="k">for</span> <span class="n">g</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">PHONETIC_MAPPING</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
        <span class="n">text</span> <span class="o">=</span> <span class="n">text</span><span class="p">.</span><span class="n">replace</span><span class="p">(</span><span class="n">g</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">CONSONANTS</span><span class="p">:</span>
        <span class="n">text</span> <span class="o">=</span> <span class="n">text</span><span class="p">.</span><span class="n">replace</span><span class="p">(</span><span class="sa">f</span><span class="s">"k</span><span class="si">{</span><span class="n">c</span><span class="si">}</span><span class="s">"</span><span class="p">,</span> <span class="sa">f</span><span class="s">"ʔ</span><span class="si">{</span><span class="n">c</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">text</span>
</code></pre></div></div>
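<p>Because each rule is applied in turn with <code class="language-plaintext highlighter-rouge">str.replace</code>, the order of <code class="language-plaintext highlighter-rouge">PHONETIC_MAPPING</code> matters (Python dicts preserve insertion order from 3.7 onwards). For instance, <code class="language-plaintext highlighter-rouge">ny</code> must be handled before <code class="language-plaintext highlighter-rouge">y</code>, and <code class="language-plaintext highlighter-rouge">j</code> must be rewritten to <code class="language-plaintext highlighter-rouge">dʒ</code> before <code class="language-plaintext highlighter-rouge">y</code> becomes <code class="language-plaintext highlighter-rouge">j</code>. A small sketch using only four of the rules (<code class="language-plaintext highlighter-rouge">apply_rules</code> is a hypothetical helper) shows what goes wrong when the order is reversed:</p>

```python
from typing import List, Tuple


def apply_rules(text: str, rules: List[Tuple[str, str]]) -> str:
    """Apply grapheme -> phoneme substitutions one after another, in order."""
    for grapheme, phoneme in rules:
        text = text.replace(grapheme, phoneme)
    return text


# Same relative order as in PHONETIC_MAPPING
good_order = [("ny", "ɲ"), ("ng", "ŋ"), ("j", "dʒ"), ("y", "j")]
# Reversed: "y" becomes "j" first, and that "j" is then turned into "dʒ"
bad_order = list(reversed(good_order))

for word in ["sampingnya", "tayangan"]:
    print(word, apply_rules(word, good_order), apply_rules(word, bad_order))
# sampingnya sampiŋɲa sampiŋndʒa
# tayangan tajaŋan tadʒaŋan
```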

<p>The code is written in Python, with very basic <em>if-this-then-that</em> rules. This approach made a lot of sense, given that the changes from graphemes to IPA phonemes aren’t too drastic, at least in our case. A sequence2sequence model could certainly learn the same mapping, but it would probably need a larger and more diverse training set.</p>

<p>That doesn’t mean the English g2p approach using a GRU was ineffective! Notice that their phonemes are in the <a href="https://en.wikipedia.org/wiki/ARPABET">ARPAbet</a> format, which looks much less like the written word than the IPA transcriptions we used; on top of that, English spelling is far less regular than Indonesian spelling. Their approach made complete sense because the model has to bridge two quite different text domains. This is the same reason translation tasks are better off using a sequence2sequence neural net than hand-written rules. It would take ages, if it were possible at all, to code up every rule of translation between two languages, but a recurrent model like a GRU can learn these “hidden translation rules” automatically, if they exist.</p>

<h1 id="a-problem-with-the-letter-e">A problem with the letter E</h1>

<p>But there was a huge issue with the rule-based approach we took: according to <a href="https://ivanlanin.github.io/puebi/huruf/huruf-vokal/">PUEBI</a>, the Indonesian spelling guidelines, there are 3 ways to pronounce the letter <code class="language-plaintext highlighter-rouge">e</code> in Indonesian. The lexicon that we used further limits the pronunciation to only two: a close-mid front unrounded vowel <code class="language-plaintext highlighter-rouge">e</code> or a mid central vowel <code class="language-plaintext highlighter-rouge">ə</code>. For example, the word <code class="language-plaintext highlighter-rouge">bebek</code> (meaning: duck) has the phoneme <code class="language-plaintext highlighter-rouge">bebek</code>, while the word <code class="language-plaintext highlighter-rouge">delapan</code> (meaning: eight) has the phoneme <code class="language-plaintext highlighter-rouge">dəlapan</code>. A single word may even contain several <code class="language-plaintext highlighter-rouge">e</code>’s pronounced differently, like the word <code class="language-plaintext highlighter-rouge">mereka</code> (meaning: they), which is pronounced <code class="language-plaintext highlighter-rouge">məreka</code>. You can hear how they sound through the Google Translate TTS <a href="https://translate.google.com/?sl=id&amp;tl=en&amp;text=bebek&amp;op=translate">here</a>, <a href="https://translate.google.com/?sl=id&amp;tl=en&amp;text=delapan&amp;op=translate">here</a>, and <a href="https://translate.google.com/?sl=id&amp;tl=en&amp;text=mereka&amp;op=translate">here</a>.</p>
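<p>To see why no single substitution rule can cover the letter <code class="language-plaintext highlighter-rouge">e</code>, align each spelling with its transcription character by character (possible here because only <code class="language-plaintext highlighter-rouge">e</code> versus <code class="language-plaintext highlighter-rouge">ə</code> differs, so the lengths match) and count what each <code class="language-plaintext highlighter-rouge">e</code> becomes. A minimal sketch over the three example words above:</p>

```python
from collections import Counter

# (spelling, transcription) pairs quoted above
pairs = [("bebek", "bebek"), ("delapan", "dəlapan"), ("mereka", "məreka")]

e_outcomes = Counter()
for spelling, transcription in pairs:
    if len(spelling) != len(transcription):
        raise ValueError(f"cannot align {spelling!r} with {transcription!r}")
    for grapheme, phoneme in zip(spelling, transcription):
        if grapheme == "e":
            e_outcomes[phoneme] += 1

print(e_outcomes)
# Counter({'e': 3, 'ə': 2})
```

<p>The same written letter maps to two different phonemes, even within one word, so the choice has to depend on context.</p>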

<p>To the best of our knowledge, there isn’t a linguistic rule that determines exactly how a particular <code class="language-plaintext highlighter-rouge">e</code> should sound. KBBI may offer pronunciation hints for this purpose, particularly for homographs, but non-homographs get no such hints. I personally think this is a huge problem, especially for new learners of the language. Native speakers like me find the distinction between the two <code class="language-plaintext highlighter-rouge">e</code>’s natural, but I can’t imagine being in the shoes of someone learning the language.</p>

<p>To be fair, Indonesian isn’t like English, where there are “native speakers” whom we can consult. Indonesian is a lingua franca, a standardized version of Malay, and was heavily influenced by Dutch and many regional languages such as Javanese and Sundanese. There might not be a definitive “correct” way to pronounce the <code class="language-plaintext highlighter-rouge">e</code> in a given word, because determining it would require consulting the word’s origin. Furthermore, different regions of Indonesia may pronounce the same word differently due to their dialects. You can read more about this <a href="https://id.quora.com/Mengapa-terdapat-perbedaan-pelafalan-huruf-E-dalam-beberapa-kata-yang-berbahasa-Indonesia-Contohnya-bendera-dan-benderang/answer/Benny-Lin">here</a> and <a href="https://id.quora.com/Mengapa-terdapat-perbedaan-pelafalan-huruf-E-dalam-beberapa-kata-yang-berbahasa-Indonesia-Contohnya-bendera-dan-benderang/answer/Gladhys-Elliona-Syahutari">here</a>. Both discussions are in Indonesian, but Google Translate should do the job.</p>

<p>In any case, our g2p package needs a way to distinguish <code class="language-plaintext highlighter-rouge">e</code>’s from <code class="language-plaintext highlighter-rouge">ə</code>’s. Once that distinction has been made, we can simply pass it to the hand-written g2p algorithm that does the rest of the job.</p>

<h1 id="formulating-the-problem">Formulating the Problem</h1>

<p>At first, we thought a sequence2sequence model could do the job just fine. We could simply train it on pairs of data like:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">bebek</code> &amp; <code class="language-plaintext highlighter-rouge">bebek</code></li>
  <li><code class="language-plaintext highlighter-rouge">delapan</code> &amp; <code class="language-plaintext highlighter-rouge">dəlapan</code></li>
  <li><code class="language-plaintext highlighter-rouge">mereka</code> &amp; <code class="language-plaintext highlighter-rouge">məreka</code></li>
</ul>

<p>and then pass its output to the hand-written g2p rules. But after more thought, we recalled the pitfalls of this method and realized it would suffer from the same issues as before: poor OOV performance, outputs of the wrong length, and so on. So we reformulated the problem.</p>

<p>Instead of treating phoneme prediction as a generation problem, why not treat it as a de-masking problem? That is, instead of training an autoregressive model like an LSTM, why not train an autoencoding model like <a href="https://arxiv.org/abs/1810.04805">BERT</a>?</p>

<p>Normally, a BERT model is trained as a masked language model over (sub)word tokens; think of a fill-in-the-blanks exercise. Given the context:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>The weather is good today, the ___ is bright and blue.
</code></pre></div></div>

<p>or</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Have a ____ and relax.
</code></pre></div></div>

<p>You can probably infer what the blanks should be. That is exactly how BERT is trained: it sees the neighbors of the masked (blanked-out) word and makes a prediction based on them. Realizing this, I saw an intriguing possibility to apply the same mechanics to our problem with the letter <code class="language-plaintext highlighter-rouge">e</code>. That is, frame the problem as:</p>

<ul>
  <li>Context: <code class="language-plaintext highlighter-rouge">b _ b _ k</code>, Output: <code class="language-plaintext highlighter-rouge">b e b e k</code></li>
  <li>Context: <code class="language-plaintext highlighter-rouge">d _ l a p a n</code>, Output: <code class="language-plaintext highlighter-rouge">d ə l a p a n</code></li>
  <li>Context: <code class="language-plaintext highlighter-rouge">m _ r _ k a</code>, Output: <code class="language-plaintext highlighter-rouge">m ə r e k a</code></li>
</ul>
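<p>Producing these context/output pairs from the transcriptions is mechanical. A minimal sketch, using <code class="language-plaintext highlighter-rouge">_</code> as the blank just like the examples above (the training code uses a dedicated <code class="language-plaintext highlighter-rouge">[mask]</code> token instead); <code class="language-plaintext highlighter-rouge">make_example</code> is a hypothetical helper:</p>

```python
from typing import Tuple

E_VARIANTS = {"e", "ə"}  # both pronunciations of the letter e get blanked out


def make_example(transcription: str, blank: str = "_") -> Tuple[str, str]:
    """Turn a transcription into a (context, output) pair of space-separated characters."""
    chars = list(transcription)
    context = [blank if c in E_VARIANTS else c for c in chars]
    return " ".join(context), " ".join(chars)


for t in ["bebek", "dəlapan", "məreka"]:
    print(make_example(t))
# ('b _ b _ k', 'b e b e k')
# ('d _ l a p a n', 'd ə l a p a n')
# ('m _ r _ k a', 'm ə r e k a')
```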

<p>and so on. The hope is that, given the neighboring letters, the BERT model will be able to infer which pronunciation of <code class="language-plaintext highlighter-rouge">e</code> to use.</p>

<p>In my research, I have not found anyone else using this approach. I didn’t know whether the idea was simply a bad one, so I gave it a try. Why not?</p>

<h1 id="code">Code</h1>

<h2 id="dataset">Dataset</h2>

<p>This is the training dataset that I ended up with. Recall that we will mask out the <code class="language-plaintext highlighter-rouge">e</code>’s later and let the model predict the appropriate phonetic <code class="language-plaintext highlighter-rouge">e</code>. Again, this dataset originates from <a href="https://github.com/open-dict-data/ipa-dict/blob/master/data/ma.txt">ipa-dict</a>, which we pre-processed and modified. You can find our version <a href="https://huggingface.co/datasets/bookbot/id_word2phoneme">here</a>.</p>

<table>
  <thead>
    <tr>
      <th> </th>
      <th>word</th>
      <th>target</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>- - n y a</td>
      <td>- - n y a</td>
    </tr>
    <tr>
      <td>1</td>
      <td>- a n d a</td>
      <td>- a n d a</td>
    </tr>
    <tr>
      <td>2</td>
      <td>- b a u r</td>
      <td>- b a u r</td>
    </tr>
    <tr>
      <td>3</td>
      <td>- b e l a s</td>
      <td>- b ə l a s</td>
    </tr>
    <tr>
      <td>4</td>
      <td>- c o m p e n g</td>
      <td>- c o m p e n g</td>
    </tr>
    <tr>
      <td>…</td>
      <td>…</td>
      <td>…</td>
    </tr>
    <tr>
      <td>27547</td>
      <td>z o h o r</td>
      <td>z o h o r</td>
    </tr>
    <tr>
      <td>27548</td>
      <td>z o n a</td>
      <td>z o n a</td>
    </tr>
    <tr>
      <td>27549</td>
      <td>z u h u r</td>
      <td>z u h u r</td>
    </tr>
    <tr>
      <td>27550</td>
      <td>z u l k a r n a i n</td>
      <td>z u l k a r n a i n</td>
    </tr>
    <tr>
      <td>27551</td>
      <td>z u r i a t</td>
      <td>z u r i a t</td>
    </tr>
  </tbody>
</table>
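<p>In the rows above, the two columns appear to differ only where the transcription has <code class="language-plaintext highlighter-rouge">ə</code>: the <code class="language-plaintext highlighter-rouge">word</code> column spells every <code class="language-plaintext highlighter-rouge">e</code> as plain <code class="language-plaintext highlighter-rouge">e</code>, and both columns separate characters with spaces. Assuming that holds for the whole dataset, the input column can be recovered from the target column, as in this sketch (<code class="language-plaintext highlighter-rouge">spaced</code> and <code class="language-plaintext highlighter-rouge">word_from_target</code> are hypothetical helpers):</p>

```python
def spaced(text: str) -> str:
    """Separate characters with spaces, e.g. 'zona' becomes 'z o n a'."""
    return " ".join(text)


def word_from_target(target: str) -> str:
    """Collapse both pronunciations back to the written letter e."""
    return target.replace("ə", "e")


print(spaced("zona"))                   # z o n a
print(word_from_target("- b ə l a s"))  # - b e l a s
```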

<h2 id="character-level-masked-language-model">Character-Level Masked Language Model</h2>

<p>Now, I had never written a BERT masked language model from scratch, so I followed a very nice guide from <a href="https://keras.io/examples/nlp/masked_language_modeling/">Keras</a>, written by <a href="https://twitter.com/ankur310794">Ankur Singh</a>. It’s clear and easy to customize for our use case, so I went with it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span>
<span class="kn">from</span> <span class="nn">tensorflow.keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">from</span> <span class="nn">tensorflow.keras.layers</span> <span class="kn">import</span> <span class="n">TextVectorization</span>
<span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">Config</span><span class="p">:</span>
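    <span class="n">MAX_LEN</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
    <span class="n">BATCH_SIZE</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span>
    <span class="n">LR</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span>
    <span class="n">VOCAB_SIZE</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
    <span class="n">EMBED_DIM</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span>
    <span class="n">NUM_HEAD</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span>
    <span class="n">FF_DIM</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span>
    <span class="n">NUM_LAYERS</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span>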

<span class="n">config</span> <span class="o">=</span> <span class="n">Config</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="tokenization-and-preprocessing">Tokenization and Preprocessing</h3>

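<p>The tutorial used a Keras <code class="language-plaintext highlighter-rouge">TextVectorization</code> layer for tokenization, which I also find easy to use and customize. Because <code class="language-plaintext highlighter-rouge">TextVectorization</code> splits on whitespace by default, each space-separated character in our dataset becomes a single token. The only change I made was simplifying the text standardization function to plain lowercasing, which, unlike the default standardization, does not strip punctuation such as <code class="language-plaintext highlighter-rouge">-</code>.</p>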

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_vectorize_layer</span><span class="p">(</span><span class="n">texts</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">max_seq</span><span class="p">,</span> <span class="n">special_tokens</span><span class="o">=</span><span class="p">[</span><span class="s">"[MASK]"</span><span class="p">]):</span>
    <span class="n">vectorize_layer</span> <span class="o">=</span> <span class="n">TextVectorization</span><span class="p">(</span>
        <span class="n">max_tokens</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
        <span class="n">output_mode</span><span class="o">=</span><span class="s">"int"</span><span class="p">,</span>
        <span class="n">standardize</span><span class="o">=</span><span class="k">lambda</span> <span class="n">input_data</span><span class="p">:</span> <span class="n">tf</span><span class="p">.</span><span class="n">strings</span><span class="p">.</span><span class="n">lower</span><span class="p">(</span><span class="n">input_data</span><span class="p">),</span>
        <span class="n">output_sequence_length</span><span class="o">=</span><span class="n">max_seq</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">vectorize_layer</span><span class="p">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">texts</span><span class="p">)</span>

    <span class="n">vocab</span> <span class="o">=</span> <span class="n">vectorize_layer</span><span class="p">.</span><span class="n">get_vocabulary</span><span class="p">()</span>

    <span class="n">vocab</span> <span class="o">=</span> <span class="n">vocab</span><span class="p">[</span><span class="mi">2</span> <span class="p">:</span> <span class="n">vocab_size</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">special_tokens</span><span class="p">)]</span> <span class="o">+</span> <span class="p">[</span><span class="s">"[mask]"</span><span class="p">]</span>
    <span class="n">vectorize_layer</span><span class="p">.</span><span class="n">set_vocabulary</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">vectorize_layer</span>

<span class="n">vectorize_layer</span> <span class="o">=</span> <span class="n">get_vectorize_layer</span><span class="p">(</span>
    <span class="n">df</span><span class="p">.</span><span class="n">target</span><span class="p">.</span><span class="n">values</span><span class="p">.</span><span class="n">tolist</span><span class="p">(),</span>
    <span class="n">config</span><span class="p">.</span><span class="n">VOCAB_SIZE</span><span class="p">,</span>
    <span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span><span class="p">,</span>
    <span class="n">special_tokens</span><span class="o">=</span><span class="p">[</span><span class="s">"[mask]"</span><span class="p">],</span>
<span class="p">)</span>
</code></pre></div></div>
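<p>To make the resulting token ids concrete, here is a plain-Python approximation of the vocabulary layout: <code class="language-plaintext highlighter-rouge">TextVectorization</code> reserves index 0 for padding and index 1 for <code class="language-plaintext highlighter-rouge">[UNK]</code>, orders the remaining tokens by descending frequency, and <code class="language-plaintext highlighter-rouge">get_vectorize_layer</code> puts <code class="language-plaintext highlighter-rouge">[mask]</code> at the end. The sketch breaks frequency ties alphabetically, which is an assumption rather than Keras’s documented behavior, so exact ids may differ; <code class="language-plaintext highlighter-rouge">build_vocab</code> is a hypothetical helper. With 28 distinct characters, <code class="language-plaintext highlighter-rouge">[mask]</code> would land at id 2 + 28 = 30, matching the mask id in the example output further down.</p>

```python
from collections import Counter
from typing import List


def build_vocab(texts: List[str], vocab_size: int, special_tokens: List[str]) -> List[str]:
    """Approximate the vocabulary produced by get_vectorize_layer (ties broken alphabetically)."""
    counts = Counter(token for text in texts for token in text.lower().split())
    # Most frequent first; alphabetical order only makes ties deterministic
    ranked = sorted(counts, key=lambda token: (-counts[token], token))
    # Leave room for padding, [UNK], and the special tokens within vocab_size
    ranked = ranked[: vocab_size - 2 - len(special_tokens)]
    return ["", "[UNK]"] + ranked + special_tokens


vocab = build_vocab(["m ə r e k a", "d ə l a p a n"], vocab_size=32, special_tokens=["[mask]"])
print(vocab.index("a"), vocab.index("ə"), vocab.index("[mask]"))
# 2 3 12
```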

<p>This is where most of the changes were made. First, instead of masking characters at random, a “hard mask” is applied to every <code class="language-plaintext highlighter-rouge">e</code> and <code class="language-plaintext highlighter-rouge">ə</code> token, so all of them are masked out in every text and nothing else is. This means that the tutorial’s 15% BERT masking, its 90%/10% random masking, and its 10% random swaps were all removed. I found that also masking characters other than <code class="language-plaintext highlighter-rouge">e</code>’s hurt performance. I suspect this made the problem even harder to learn, since a single word provides very little context to begin with.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Get mask token id for masked language model
</span><span class="n">mask_token_id</span> <span class="o">=</span> <span class="n">vectorize_layer</span><span class="p">([</span><span class="s">"[mask]"</span><span class="p">]).</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="n">e1_token_id</span> <span class="o">=</span> <span class="n">vectorize_layer</span><span class="p">([</span><span class="s">"e"</span><span class="p">]).</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="n">e2_token_id</span> <span class="o">=</span> <span class="n">vectorize_layer</span><span class="p">([</span><span class="s">"ə"</span><span class="p">]).</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="n">texts</span><span class="p">):</span>
    <span class="n">encoded_texts</span> <span class="o">=</span> <span class="n">vectorize_layer</span><span class="p">(</span><span class="n">texts</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">encoded_texts</span><span class="p">.</span><span class="n">numpy</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">get_masked_input_and_labels</span><span class="p">(</span><span class="n">encoded_texts</span><span class="p">):</span>
    <span class="c1"># No random BERT masking: start from an all-False mask
</span>    <span class="n">inp_mask</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">encoded_texts</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
    <span class="c1"># Force mask e's
</span>    <span class="n">inp_mask</span><span class="p">[</span><span class="n">encoded_texts</span> <span class="o">==</span> <span class="n">e1_token_id</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
    <span class="n">inp_mask</span><span class="p">[</span><span class="n">encoded_texts</span> <span class="o">==</span> <span class="n">e2_token_id</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
    <span class="c1"># Set targets to -1 by default, it means ignore
</span>    <span class="n">labels</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">encoded_texts</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span>
    <span class="c1"># Set labels for masked tokens
</span>    <span class="n">labels</span><span class="p">[</span><span class="n">inp_mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoded_texts</span><span class="p">[</span><span class="n">inp_mask</span><span class="p">]</span>

    <span class="c1"># Prepare input
</span>    <span class="n">encoded_texts_masked</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">copy</span><span class="p">(</span><span class="n">encoded_texts</span><span class="p">)</span>
    <span class="n">encoded_texts_masked</span><span class="p">[</span><span class="n">inp_mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask_token_id</span>
    <span class="c1"># note: we don't randomly change chars and apply all masks
</span>
    <span class="c1"># Prepare sample_weights to pass to .fit() method
</span>    <span class="n">sample_weights</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">sample_weights</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="c1"># y_labels would be same as encoded_texts i.e input tokens
</span>    <span class="n">y_labels</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">copy</span><span class="p">(</span><span class="n">encoded_texts</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">encoded_texts_masked</span><span class="p">,</span> <span class="n">y_labels</span><span class="p">,</span> <span class="n">sample_weights</span>
</code></pre></div></div>

<p>Here’s an example of the input, label, and weight arrays, respectively. Notice that at the positions of <code class="language-plaintext highlighter-rouge">e</code> and <code class="language-plaintext highlighter-rouge">ə</code>, the input holds the mask token id <code class="language-plaintext highlighter-rouge">30</code>, while the labels hold the target token ids <code class="language-plaintext highlighter-rouge">18</code> and <code class="language-plaintext highlighter-rouge">4</code>, corresponding to <code class="language-plaintext highlighter-rouge">e</code> and <code class="language-plaintext highlighter-rouge">ə</code>, respectively. Also notice that the weights are <code class="language-plaintext highlighter-rouge">0</code> for unmasked tokens and <code class="language-plaintext highlighter-rouge">1</code> for masked tokens, so that the model is only “graded” on its performance on the blanks.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">get_masked_input_and_labels</span><span class="p">(</span><span class="n">encode</span><span class="p">(</span><span class="s">"m e r d ə k a"</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(array([ 8, 30,  6, 16, 30,  7,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([ 8, 18,  6, 16,  4,  7,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
</code></pre></div></div>
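<p>Because the weights are zero everywhere except at the blanks, a weighted loss or metric ignores the unmasked characters entirely. The sketch below computes a weighted cross-entropy and accuracy in NumPy to illustrate this. Normalizing by the sum of the weights is this sketch’s choice; Keras’s built-in loss reduction may normalize differently, so absolute values need not match what <code class="language-plaintext highlighter-rouge">.fit()</code> reports. <code class="language-plaintext highlighter-rouge">masked_metrics</code> is a hypothetical helper.</p>

```python
import numpy as np


def masked_metrics(probs: np.ndarray, labels: np.ndarray, weights: np.ndarray):
    """Weighted cross-entropy and accuracy, counting only positions with weight 1.

    probs:   (batch, seq_len, vocab) predicted probabilities
    labels:  (batch, seq_len) integer target ids
    weights: (batch, seq_len) 1.0 at masked positions, 0.0 elsewhere
    """
    # Probability assigned to the correct token at every position
    picked = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    token_loss = -np.log(np.clip(picked, 1e-12, 1.0))
    correct = (probs.argmax(axis=-1) == labels).astype(float)
    total = weights.sum()
    return (token_loss * weights).sum() / total, (correct * weights).sum() / total


# One sequence, three positions, a two-token vocabulary
probs = np.array([[[0.2, 0.8], [0.9, 0.1], [0.7, 0.3]]])
labels = np.array([[0, 0, 1]])
weights = np.array([[0.0, 1.0, 1.0]])  # position 0 is unmasked and ignored

loss, accuracy = masked_metrics(probs, labels, weights)
print(round(loss, 4), accuracy)
# 0.6547 0.5
```

<p>Position 0 is predicted wrongly, yet it changes neither number because its weight is zero.</p>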

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Prepare data for masked language model
</span><span class="n">x_all</span> <span class="o">=</span> <span class="n">encode</span><span class="p">(</span><span class="n">df</span><span class="p">.</span><span class="n">target</span><span class="p">.</span><span class="n">values</span><span class="p">)</span>
<span class="n">x_masked_train</span><span class="p">,</span> <span class="n">y_masked_labels</span><span class="p">,</span> <span class="n">sample_weights</span> <span class="o">=</span> <span class="n">get_masked_input_and_labels</span><span class="p">(</span><span class="n">x_all</span><span class="p">)</span>

<span class="n">mlm_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">Dataset</span><span class="p">.</span><span class="n">from_tensor_slices</span><span class="p">(</span>
    <span class="p">(</span><span class="n">x_masked_train</span><span class="p">,</span> <span class="n">y_masked_labels</span><span class="p">,</span> <span class="n">sample_weights</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">mlm_ds</span> <span class="o">=</span> <span class="n">mlm_ds</span><span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1000</span><span class="p">).</span><span class="n">batch</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
</code></pre></div></div>
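
<p>The third element of each dataset entry, <code class="language-plaintext highlighter-rouge">sample_weights</code>, is what limits grading to the blanks. The custom <code class="language-plaintext highlighter-rouge">train_step</code> defined below passes it to a <code class="language-plaintext highlighter-rouge">SparseCategoricalCrossentropy</code> loss built with <code class="language-plaintext highlighter-rouge">reduction=NONE</code>. As far as I can tell, that loss keeps one value per character, scaled by the character’s weight. The <code class="language-plaintext highlighter-rouge">Mean</code> tracker, updated with the same weights, then reports an average over the masked characters only. Here is a NumPy sketch of that arithmetic on the <code class="language-plaintext highlighter-rouge">m e r d ə k a</code> example. It is not the Keras implementation itself, and the helper names and toy probabilities are made up for illustration.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def masked_token_losses(labels, probs, sample_weight):
    # Hypothetical helper: per-token cross-entropy, -log p(true token),
    # scaled by the token's weight (mimicking reduction=NONE plus sample_weight).
    rows = np.arange(labels.shape[0])
    return -np.log(probs[rows, labels]) * sample_weight


def tracked_loss(weighted_losses, sample_weight):
    # Hypothetical helper: weighted mean, the way a Mean metric updated
    # with sample_weight accumulates sum(values * w) / sum(w).
    return np.sum(weighted_losses * sample_weight) / np.sum(sample_weight)


vocab_size = 32
labels = np.array([8, 18, 6, 16, 4, 7, 2])  # "m e r d ə k a"
weights = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0])  # e and ə are masked

# Toy predictions: uniform (i.e. clueless) at every position...
probs = np.full((len(labels), vocab_size), 1.0 / vocab_size)
# ...except the two blanks, where the true token gets 0.5 and 0.25.
for pos, p_true in [(1, 0.5), (4, 0.25)]:
    probs[pos] = (1.0 - p_true) / (vocab_size - 1)
    probs[pos, labels[pos]] = p_true

losses = masked_token_losses(labels, probs, weights)
print(losses.round(4))                          # zero everywhere except positions 1 and 4
print(round(tracked_loss(losses, weights), 4))  # (log 2 + log 4) / 2, about 1.0397
</code></pre></div></div>

<p>The unmasked positions have terrible (uniform) predictions, yet they add nothing to the loss. Only the two blanks move the number the training log reports.</p>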

<h3 id="bert">BERT</h3>

<p>There’s really no difference between the code in the Keras guide and the code I have here; I’ll just note how elegant Keras code is for a model like BERT. In any case, this model is exactly the same as the one we would train as a word-level masked language model, except that the input tokens are characters instead of words. Same old objective, same architecture, and so on.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">bert_module</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
    <span class="c1"># Multi headed self-attention
</span>    <span class="n">attention_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span>
        <span class="n">num_heads</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">NUM_HEAD</span><span class="p">,</span>
        <span class="n">key_dim</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">EMBED_DIM</span> <span class="o">//</span> <span class="n">config</span><span class="p">.</span><span class="n">NUM_HEAD</span><span class="p">,</span>
        <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/multiheadattention"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">),</span>
    <span class="p">)(</span><span class="n">query</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>  <span class="c1"># Keras' call order is (query, value, key)</span>
    <span class="n">attention_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/att_dropout"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">))(</span>
        <span class="n">attention_output</span>
    <span class="p">)</span>
    <span class="n">attention_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">LayerNormalization</span><span class="p">(</span>
        <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/att_layernormalization"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
    <span class="p">)(</span><span class="n">query</span> <span class="o">+</span> <span class="n">attention_output</span><span class="p">)</span>

    <span class="c1"># Feed-forward layer
</span>    <span class="n">ffn</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
        <span class="p">[</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">FF_DIM</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">"relu"</span><span class="p">),</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">EMBED_DIM</span><span class="p">),</span>
        <span class="p">],</span>
        <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/ffn"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">),</span>
    <span class="p">)</span>
    <span class="n">ffn_output</span> <span class="o">=</span> <span class="n">ffn</span><span class="p">(</span><span class="n">attention_output</span><span class="p">)</span>
    <span class="n">ffn_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/ffn_dropout"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">))(</span>
        <span class="n">ffn_output</span>
    <span class="p">)</span>
    <span class="n">sequence_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">LayerNormalization</span><span class="p">(</span>
        <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"encoder_{}/ffn_layernormalization"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
    <span class="p">)(</span><span class="n">attention_output</span> <span class="o">+</span> <span class="n">ffn_output</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">sequence_output</span>


<span class="k">def</span> <span class="nf">get_pos_encoding_matrix</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_emb</span><span class="p">):</span>
    <span class="n">pos_enc</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
        <span class="p">[</span>
            <span class="p">[</span><span class="n">pos</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="n">power</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">j</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">d_emb</span><span class="p">)</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">d_emb</span><span class="p">)]</span>
            <span class="k">if</span> <span class="n">pos</span> <span class="o">!=</span> <span class="mi">0</span>
            <span class="k">else</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_emb</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">pos</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_len</span><span class="p">)</span>
        <span class="p">]</span>
    <span class="p">)</span>
    <span class="n">pos_enc</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">pos_enc</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">])</span>  <span class="c1"># dim 2i
</span>    <span class="n">pos_enc</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">pos_enc</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">])</span>  <span class="c1"># dim 2i+1
</span>    <span class="k">return</span> <span class="n">pos_enc</span>


<span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span>
    <span class="n">reduction</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">Reduction</span><span class="p">.</span><span class="n">NONE</span>
<span class="p">)</span>
<span class="n">loss_tracker</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">metrics</span><span class="p">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s">"loss"</span><span class="p">)</span>


<span class="k">class</span> <span class="nc">MaskedLanguageModel</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
            <span class="n">features</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">sample_weight</span> <span class="o">=</span> <span class="n">inputs</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">features</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">inputs</span>
            <span class="n">sample_weight</span> <span class="o">=</span> <span class="bp">None</span>

        <span class="k">with</span> <span class="n">tf</span><span class="p">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
            <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight</span><span class="p">)</span>

        <span class="c1"># Compute gradients
</span>        <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">trainable_variables</span>
        <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="p">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span>

        <span class="c1"># Update weights
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span>

        <span class="c1"># Compute our own metrics
</span>        <span class="n">loss_tracker</span><span class="p">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight</span><span class="p">)</span>

        <span class="c1"># Return a dict mapping metric names to current value
</span>        <span class="k">return</span> <span class="p">{</span><span class="s">"loss"</span><span class="p">:</span> <span class="n">loss_tracker</span><span class="p">.</span><span class="n">result</span><span class="p">()}</span>

    <span class="o">@</span><span class="nb">property</span>
    <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="p">[</span><span class="n">loss_tracker</span><span class="p">]</span>


<span class="k">def</span> <span class="nf">create_masked_language_bert_model</span><span class="p">():</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Input</span><span class="p">((</span><span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">int64</span><span class="p">)</span>

    <span class="n">word_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Embedding</span><span class="p">(</span>
        <span class="n">config</span><span class="p">.</span><span class="n">VOCAB_SIZE</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">EMBED_DIM</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"word_embedding"</span>
    <span class="p">)(</span><span class="n">inputs</span><span class="p">)</span>
    <span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Embedding</span><span class="p">(</span>
        <span class="n">input_dim</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span><span class="p">,</span>
        <span class="n">output_dim</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">EMBED_DIM</span><span class="p">,</span>
        <span class="n">weights</span><span class="o">=</span><span class="p">[</span><span class="n">get_pos_encoding_matrix</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span><span class="p">,</span> <span class="n">config</span><span class="p">.</span><span class="n">EMBED_DIM</span><span class="p">)],</span>
        <span class="n">name</span><span class="o">=</span><span class="s">"position_embedding"</span><span class="p">,</span>
    <span class="p">)(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">limit</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span><span class="p">,</span> <span class="n">delta</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">embeddings</span> <span class="o">=</span> <span class="n">word_embeddings</span> <span class="o">+</span> <span class="n">position_embeddings</span>

    <span class="n">encoder_output</span> <span class="o">=</span> <span class="n">embeddings</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">NUM_LAYERS</span><span class="p">):</span>
        <span class="n">encoder_output</span> <span class="o">=</span> <span class="n">bert_module</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>

    <span class="n">mlm_output</span> <span class="o">=</span> <span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">VOCAB_SIZE</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"mlm_cls"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">"softmax"</span><span class="p">)(</span>
        <span class="n">encoder_output</span>
    <span class="p">)</span>
    <span class="n">mlm_model</span> <span class="o">=</span> <span class="n">MaskedLanguageModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">mlm_output</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"masked_bert_model"</span><span class="p">)</span>

    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">config</span><span class="p">.</span><span class="n">LR</span><span class="p">)</span>
    <span class="n">mlm_model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">mlm_model</span>
</code></pre></div></div>
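
<p>The <code class="language-plaintext highlighter-rouge">get_pos_encoding_matrix</code> helper fills entry (pos, j) with pos / 10000<sup>2⌊j/2⌋/d</sup>. It then takes the sine of the even columns and the cosine of the odd columns, so rows 1 onward follow the original Transformer’s sinusoidal encoding. Row 0 is left as all zeros, where the textbook formula would put 1s in the cosine columns. The matrix only serves as the initial <code class="language-plaintext highlighter-rouge">weights</code> of the <code class="language-plaintext highlighter-rouge">position_embedding</code> layer. As a sanity check, here is a vectorized NumPy rewrite of my own (not from the Keras guide), compared against a copy of the loop version:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def pos_encoding_loop(max_len, d_emb):
    # Copy of get_pos_encoding_matrix from the post.
    pos_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
            if pos != 0
            else np.zeros(d_emb)
            for pos in range(max_len)
        ]
    )
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
    return pos_enc


def pos_encoding_vectorized(max_len, d_emb):
    # Hypothetical broadcast version: angles[pos, j] = pos / 10000^(2*(j//2)/d_emb).
    pos = np.arange(max_len)[:, None]  # column of positions, shape (max_len, 1)
    j = np.arange(d_emb)[None, :]      # row of dimensions, shape (1, d_emb)
    angles = pos / np.power(10000.0, 2 * (j // 2) / d_emb)
    enc = np.where(j % 2 == 0, np.sin(angles), np.cos(angles))
    enc[0] = 0.0  # match the post: position 0 stays all zeros
    return enc


loop = pos_encoding_loop(32, 128)  # MAX_LEN and EMBED_DIM as in the summary
vec = pos_encoding_vectorized(32, 128)
print(loop.shape, np.allclose(loop, vec))  # (32, 128) True
</code></pre></div></div>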

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">id2token</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">vectorize_layer</span><span class="p">.</span><span class="n">get_vocabulary</span><span class="p">()))</span>
<span class="n">token2id</span> <span class="o">=</span> <span class="p">{</span><span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">id2token</span><span class="p">.</span><span class="n">items</span><span class="p">()}</span>

<span class="n">bert_masked_model</span> <span class="o">=</span> <span class="n">create_masked_language_bert_model</span><span class="p">()</span>
<span class="n">bert_masked_model</span><span class="p">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model: "masked_bert_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 32)]         0           []

 word_embedding (Embedding)     (None, 32, 128)      4096        ['input_1[0][0]']

 tf.__operators__.add (TFOpLamb  (None, 32, 128)     0           ['word_embedding[0][0]']
 da)

 encoder_0/multiheadattention (  (None, 32, 128)     66048       ['tf.__operators__.add[0][0]',
 MultiHeadAttention)                                              'tf.__operators__.add[0][0]',
                                                                  'tf.__operators__.add[0][0]']

 encoder_0/att_dropout (Dropout  (None, 32, 128)     0           ['encoder_0/multiheadattention[0]
 )                                                               [0]']

 tf.__operators__.add_1 (TFOpLa  (None, 32, 128)     0           ['tf.__operators__.add[0][0]',
 mbda)                                                            'encoder_0/att_dropout[0][0]']

 encoder_0/att_layernormalizati  (None, 32, 128)     256         ['tf.__operators__.add_1[0][0]']
 on (LayerNormalization)

 encoder_0/ffn (Sequential)     (None, 32, 128)      33024       ['encoder_0/att_layernormalizatio
                                                                 n[0][0]']

 encoder_0/ffn_dropout (Dropout  (None, 32, 128)     0           ['encoder_0/ffn[0][0]']
 )

 tf.__operators__.add_2 (TFOpLa  (None, 32, 128)     0           ['encoder_0/att_layernormalizatio
 mbda)                                                           n[0][0]',
                                                                  'encoder_0/ffn_dropout[0][0]']

 encoder_0/ffn_layernormalizati  (None, 32, 128)     256         ['tf.__operators__.add_2[0][0]']
 on (LayerNormalization)

 encoder_1/multiheadattention (  (None, 32, 128)     66048       ['encoder_0/ffn_layernormalizatio
 MultiHeadAttention)                                             n[0][0]',
                                                                  'encoder_0/ffn_layernormalizatio
                                                                 n[0][0]',
                                                                  'encoder_0/ffn_layernormalizatio
                                                                 n[0][0]']

 encoder_1/att_dropout (Dropout  (None, 32, 128)     0           ['encoder_1/multiheadattention[0]
 )                                                               [0]']

 tf.__operators__.add_3 (TFOpLa  (None, 32, 128)     0           ['encoder_0/ffn_layernormalizatio
 mbda)                                                           n[0][0]',
                                                                  'encoder_1/att_dropout[0][0]']

 encoder_1/att_layernormalizati  (None, 32, 128)     256         ['tf.__operators__.add_3[0][0]']
 on (LayerNormalization)

 encoder_1/ffn (Sequential)     (None, 32, 128)      33024       ['encoder_1/att_layernormalizatio
                                                                 n[0][0]']

 encoder_1/ffn_dropout (Dropout  (None, 32, 128)     0           ['encoder_1/ffn[0][0]']
 )

 tf.__operators__.add_4 (TFOpLa  (None, 32, 128)     0           ['encoder_1/att_layernormalizatio
 mbda)                                                           n[0][0]',
                                                                  'encoder_1/ffn_dropout[0][0]']

 encoder_1/ffn_layernormalizati  (None, 32, 128)     256         ['tf.__operators__.add_4[0][0]']
 on (LayerNormalization)

 mlm_cls (Dense)                (None, 32, 32)       4128        ['encoder_1/ffn_layernormalizatio
                                                                 n[0][0]']

==================================================================================================
Total params: 207,392
Trainable params: 207,392
Non-trainable params: 0
__________________________________________________________________________________________________
</code></pre></div></div>
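
<p>The parameter counts in the summary can be checked by hand. A dense projection has <code class="language-plaintext highlighter-rouge">n_in × n_out</code> kernel weights plus <code class="language-plaintext highlighter-rouge">n_out</code> biases. <code class="language-plaintext highlighter-rouge">MultiHeadAttention</code> has four such projections (query, key, value, output), the feed-forward block has two, and each <code class="language-plaintext highlighter-rouge">LayerNormalization</code> has a scale and an offset per feature. The sketch below reads <code class="language-plaintext highlighter-rouge">VOCAB_SIZE = 32</code>, <code class="language-plaintext highlighter-rouge">EMBED_DIM = 128</code>, and two encoder layers off the summary. It also infers <code class="language-plaintext highlighter-rouge">FF_DIM = 128</code> from the 33,024 feed-forward parameters. The head count is an assumption; any divisor of 128 gives the same total. Note that there is no <code class="language-plaintext highlighter-rouge">position_embedding</code> row, and the total of 207,392 comes out right without its 32 × 128 = 4,096 weights. This suggests that the layer is folded in as a fixed constant rather than trained, because it is applied to a constant <code class="language-plaintext highlighter-rouge">tf.range</code> rather than to the model input.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def dense_params(n_in, n_out):
    # Kernel (n_in x n_out) plus one bias per output unit.
    return n_in * n_out + n_out


def mha_params(d_model, num_heads):
    # Query, key, and value each project d_model -> num_heads * key_dim,
    # and the output projection maps back to d_model.
    key_dim = d_model // num_heads
    inner = num_heads * key_dim
    return 3 * dense_params(d_model, inner) + dense_params(inner, d_model)


def encoder_params(d_model, ff_dim, num_heads):
    ffn = dense_params(d_model, ff_dim) + dense_params(ff_dim, d_model)
    layer_norms = 2 * (2 * d_model)  # two LayerNorms, each with gamma and beta
    return mha_params(d_model, num_heads) + ffn + layer_norms


VOCAB_SIZE, EMBED_DIM, NUM_LAYERS = 32, 128, 2  # read off the summary
FF_DIM = 128  # inferred from the 33,024 FFN parameters
NUM_HEAD = 8  # assumption; the total is the same for any divisor of 128

word_embedding = VOCAB_SIZE * EMBED_DIM        # 4,096
mlm_cls = dense_params(EMBED_DIM, VOCAB_SIZE)  # 4,128
total = word_embedding + NUM_LAYERS * encoder_params(EMBED_DIM, FF_DIM, NUM_HEAD) + mlm_cls
print(mha_params(EMBED_DIM, NUM_HEAD), total)  # 66048 207392
</code></pre></div></div>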

<h3 id="train">Train!</h3>

<p>All that’s left is to call <code class="language-plaintext highlighter-rouge">.fit()</code>; this is Keras, after all. As in the Keras guide, the model is compiled with the <a href="https://arxiv.org/abs/1412.6980">Adam optimizer</a>, which generally works well for language models.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bert_masked_model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">mlm_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="p">.</span><span class="n">callbacks</span><span class="p">.</span><span class="n">TensorBoard</span><span class="p">(</span><span class="n">log_dir</span><span class="o">=</span><span class="s">"./logs"</span><span class="p">)]</span>
<span class="p">)</span>

<span class="n">bert_masked_model</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">"bert_mlm.h5"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Epoch 1/100
216/216 [==============================] - 8s 13ms/step - loss: 0.4276
Epoch 2/100
216/216 [==============================] - 3s 13ms/step - loss: 0.3865
Epoch 3/100
216/216 [==============================] - 3s 12ms/step - loss: 0.3320
Epoch 4/100
216/216 [==============================] - 3s 12ms/step - loss: 0.3048
Epoch 5/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2887
Epoch 6/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2870
Epoch 7/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2827
Epoch 8/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2795
Epoch 9/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2939
Epoch 10/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2751
Epoch 11/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2743
Epoch 12/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2678
Epoch 13/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2671
Epoch 14/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2609
Epoch 15/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2619
Epoch 16/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2681
Epoch 17/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2689
Epoch 18/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2582
Epoch 19/100
216/216 [==============================] - 4s 16ms/step - loss: 0.2526
Epoch 20/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2559
Epoch 21/100
216/216 [==============================] - 3s 14ms/step - loss: 0.2506
Epoch 22/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2548
Epoch 23/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2584
Epoch 24/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2502
Epoch 25/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2484
Epoch 26/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2448
Epoch 27/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2502
Epoch 28/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2471
Epoch 29/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2471
Epoch 30/100
216/216 [==============================] - 4s 20ms/step - loss: 0.2422
Epoch 31/100
216/216 [==============================] - 5s 22ms/step - loss: 0.2412
Epoch 32/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2398
Epoch 33/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2500
Epoch 34/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2445
Epoch 35/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2407
Epoch 36/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2376
Epoch 37/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2351
Epoch 38/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2363
Epoch 39/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2377
Epoch 40/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2351
Epoch 41/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2467
Epoch 42/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2408
Epoch 43/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2332
Epoch 44/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2355
Epoch 45/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2371
Epoch 46/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2353
Epoch 47/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2293
Epoch 48/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2270
Epoch 49/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2258
Epoch 50/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2255
Epoch 51/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2240
Epoch 52/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2309
Epoch 53/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2336
Epoch 54/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2297
Epoch 55/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2279
Epoch 56/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2245
Epoch 57/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2239
Epoch 58/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2225
Epoch 59/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2237
Epoch 60/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2213
Epoch 61/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2210
Epoch 62/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2186
Epoch 63/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2187
Epoch 64/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2191
Epoch 65/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2165
Epoch 66/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2172
Epoch 67/100
216/216 [==============================] - 5s 23ms/step - loss: 0.2182
Epoch 68/100
216/216 [==============================] - 4s 20ms/step - loss: 0.2143
Epoch 69/100
216/216 [==============================] - 5s 23ms/step - loss: 0.2171
Epoch 70/100
216/216 [==============================] - 4s 19ms/step - loss: 0.2096
Epoch 71/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2122
Epoch 72/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2169
Epoch 73/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2134
Epoch 74/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2117
Epoch 75/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2094
Epoch 76/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2123
Epoch 77/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2134
Epoch 78/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2117
Epoch 79/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2064
Epoch 80/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2111
Epoch 81/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2130
Epoch 82/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2089
Epoch 83/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2063
Epoch 84/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2042
Epoch 85/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2032
Epoch 86/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2071
Epoch 87/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2062
Epoch 88/100
216/216 [==============================] - 3s 13ms/step - loss: 0.1999
Epoch 89/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2021
Epoch 90/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2019
Epoch 91/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2056
Epoch 92/100
216/216 [==============================] - 4s 16ms/step - loss: 0.2062
Epoch 93/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2006
Epoch 94/100
216/216 [==============================] - 3s 13ms/step - loss: 0.2034
Epoch 95/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2003
Epoch 96/100
216/216 [==============================] - 3s 12ms/step - loss: 0.2005
Epoch 97/100
216/216 [==============================] - 3s 13ms/step - loss: 0.1970
Epoch 98/100
216/216 [==============================] - 3s 13ms/step - loss: 0.1951
Epoch 99/100
216/216 [==============================] - 3s 13ms/step - loss: 0.1960
Epoch 100/100
216/216 [==============================] - 4s 20ms/step - loss: 0.1991
</code></pre></div></div>

<h3 id="inference">Inference</h3>

<p>It’s also quite simple to perform inference once the model has finished training. We first need to load the model and its weights.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Load pretrained bert model
</span><span class="n">mlm_model</span> <span class="o">=</span> <span class="n">keras</span><span class="p">.</span><span class="n">models</span><span class="p">.</span><span class="n">load_model</span><span class="p">(</span>
    <span class="s">"bert_mlm.h5"</span><span class="p">,</span> <span class="n">custom_objects</span><span class="o">=</span><span class="p">{</span><span class="s">"MaskedLanguageModel"</span><span class="p">:</span> <span class="n">MaskedLanguageModel</span><span class="p">}</span>
<span class="p">)</span>
</code></pre></div></div>

<p>Next, we write an inference function that we can reuse later. It works as follows. First, convert the input characters into integer token IDs, replacing every <code class="language-plaintext highlighter-rouge">e</code> with the mask token so that the model predicts it. Then, pad the inputs to the maximum sequence length (in our case 32) and feed the input array to the BERT model. To decode the output, we locate the masked positions, take the most probable prediction at each of them, and replace the mask tokens with those predictions. Finally, we join the tokens once they are all assembled.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="n">sequence</span><span class="p">):</span>
    <span class="n">sequence</span> <span class="o">=</span> <span class="s">" "</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="n">c</span> <span class="k">if</span> <span class="n">c</span> <span class="o">!=</span> <span class="s">"e"</span> <span class="k">else</span> <span class="s">"[mask]"</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">sequence</span><span class="p">])</span>
    <span class="n">tokens</span> <span class="o">=</span> <span class="p">[</span><span class="n">token2id</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">sequence</span><span class="p">.</span><span class="n">split</span><span class="p">()]</span>
    <span class="n">pad</span> <span class="o">=</span> <span class="p">[</span><span class="n">token2id</span><span class="p">[</span><span class="s">""</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">config</span><span class="p">.</span><span class="n">MAX_LEN</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokens</span><span class="p">))]</span>

    <span class="n">tokens</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">+</span> <span class="n">pad</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">tokens</span><span class="p">]))</span>
    <span class="n">prediction</span> <span class="o">=</span> <span class="n">mlm_model</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>

    <span class="c1"># find masked idx token
</span>    <span class="n">masked_index</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="n">mask_token_id</span><span class="p">)</span>
    <span class="n">masked_index</span> <span class="o">=</span> <span class="n">masked_index</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

    <span class="c1"># get prediction at those masked index only
</span>    <span class="n">mask_prediction</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">masked_index</span><span class="p">]</span>
    <span class="n">predicted_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">mask_prediction</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># replace mask with predicted token
</span>    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">masked_index</span><span class="p">):</span>
        <span class="n">tokens</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">predicted_ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>

    <span class="k">return</span> <span class="s">""</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="n">id2token</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tokens</span> <span class="k">if</span> <span class="n">t</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">])</span>
</code></pre></div></div>
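<p>To see the masking and decoding logic in isolation, here is a small, self-contained sketch in plain NumPy. It uses a hypothetical toy vocabulary and hand-crafted logits in place of the trained BERT model, and a shorter <code class="language-plaintext highlighter-rouge">MAX_LEN</code> than the 32 used above. Only the mechanics mirror the <code class="language-plaintext highlighter-rouge">inference</code> function: mask every <code class="language-plaintext highlighter-rouge">e</code>, right-pad with id 0, and take the argmax at the masked positions.</p>

```python
import numpy as np

# Hypothetical toy vocabulary; id 0 is the padding token "", as in the post.
vocab = ["", "[mask]", "a", "b", "e", "k", "m", "n", "s", "y", "ə"]
token2id = {t: i for i, t in enumerate(vocab)}
id2token = {i: t for t, i in token2id.items()}
mask_token_id = token2id["[mask]"]
MAX_LEN = 16  # the post uses 32; shortened for this toy example


def encode(word):
    """Replace every 'e' with [mask], map characters to ids, then right-pad with 0."""
    ids = [mask_token_id if c == "e" else token2id[c] for c in word]
    return ids + [token2id[""]] * (MAX_LEN - len(ids))


def decode(ids, logits):
    """Fill each [mask] position with the argmax of the logits at that position."""
    ids = list(ids)
    masked_index = np.where(np.array(ids) == mask_token_id)[0]
    predicted_ids = np.argmax(logits[masked_index], axis=1)
    for idx, pred in zip(masked_index, predicted_ids):
        ids[idx] = int(pred)
    # Drop padding (id 0) before joining the characters back together.
    return "".join(id2token[t] for t in ids if t != 0)


# Fake "model output": favour "ə" at every masked position of "menyebabkan".
ids = encode("menyebabkan")
logits = np.zeros((MAX_LEN, len(vocab)))
for pos in np.where(np.array(ids) == mask_token_id)[0]:
    logits[pos, token2id["ə"]] = 1.0
print(decode(ids, logits))  # mənyəbabkan
```

<p>In the real pipeline, the logits would come from <code class="language-plaintext highlighter-rouge">mlm_model.predict</code> with shape <code class="language-plaintext highlighter-rouge">(1, MAX_LEN, vocab_size)</code>, which is why the original function indexes <code class="language-plaintext highlighter-rouge">prediction[0]</code> first.</p>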

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">inference</span><span class="p">(</span><span class="s">"menyebabkannya"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'mənyəbabkannya'
</code></pre></div></div>

<p>Finally, we apply the hand-written g2p rules that we came up with.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">g2p</span><span class="p">(</span><span class="n">inference</span><span class="p">(</span><span class="s">"menyebabkannya"</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'məɲəbabkanɲa'
</code></pre></div></div>
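<p>The hand-written rules themselves aren’t shown in this post. As a rough, hypothetical sketch, they could be an ordered list of grapheme-to-phoneme substitutions applied to the model’s output. Only the <code class="language-plaintext highlighter-rouge">ny</code> → <code class="language-plaintext highlighter-rouge">ɲ</code> rule is implied by the example above; the <code class="language-plaintext highlighter-rouge">ng</code> → <code class="language-plaintext highlighter-rouge">ŋ</code> rule is an assumption added for illustration, and <code class="language-plaintext highlighter-rouge">apply_rules</code> is not the package’s actual <code class="language-plaintext highlighter-rouge">g2p</code> function.</p>

```python
# Hypothetical rules; only "ny" -> "ɲ" is implied by the example in this post.
RULES = [
    ("ng", "ŋ"),  # assumption: velar nasal
    ("ny", "ɲ"),  # palatal nasal, e.g. "nya" -> "ɲa"
]


def apply_rules(phonemes: str) -> str:
    """Apply each substitution rule, in order, to a phoneme string."""
    for grapheme, phoneme in RULES:
        phonemes = phonemes.replace(grapheme, phoneme)
    return phonemes


print(apply_rules("mənyəbabkannya"))  # məɲəbabkanɲa
```

<p>Rule order matters once rules overlap: a longer rule such as a hypothetical <code class="language-plaintext highlighter-rouge">ngg</code> would have to run before <code class="language-plaintext highlighter-rouge">ng</code>, otherwise the shorter rule would consume its prefix first.</p>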

<p>And thus we are done.</p>

<p>In practice, I would convert the Keras model to ONNX so that I can run the static model with only ONNX Runtime and NumPy as dependencies, instead of TensorFlow/Keras. But it’s really up to your use case.</p>

<h1 id="conclusion">Conclusion</h1>

<p>This little weekend experiment of mine is pretty much just a proof of concept, with plenty of room for improvement. But at least I’m happy that it worked better than the LSTM. It’s much more controllable, and its guesses for OOV words aren’t too shabby.</p>

<p>This will be available once the g2p package we’re developing becomes open source. Hopefully it is by the time that this blog post becomes live. Otherwise, we’re still working on it :)</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Transformer" /><summary type="html"><![CDATA[Our team at Bookbot is currently developing a grapheme-to-phoneme Python package for Bahasa Indonesia. The package is highly inspired by its English counterpart, g2p. A lot of our design and methods are borrowed from that library, most notably the steps to predict phonemes. The English g2p used the following algorithm (c.f. g2p’s README):]]></summary></entry><entry><title type="html">My HuggingFace JAX Community Week Experience</title><link href="https://wilsonwongso.dev/posts/2021/07/hf-jax-week/" rel="alternate" type="text/html" title="My HuggingFace JAX Community Week Experience" /><published>2021-07-30T00:00:00+10:00</published><updated>2021-07-30T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2021/07/huggingface-jax-community-week</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2021/07/hf-jax-week/"><![CDATA[<p>On June 23, the HuggingFace team announced that they are planning to host a <a href="https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104">community week</a> together with the people from the Google Cloud team. The main gist of this event was getting everyone to learn and use HuggingFace’s newly integrated JAX framework. But aside from just learning from tutorials, we were equipped with blazing fast TPUs thanks to the amazing Google Cloud team 🤯.</p>

<p>Hearing this, I naturally gravitated towards registering for the event, and since it was a team event, I immediately invited my good friend, <a href="https://stevenlimcorn.github.io">Steven Limcorn</a>, to join me. We hopped into a Discord call and the brainstorming began...</p>

<h2 id="-plan-on-paper">📝 Plan on Paper</h2>

<p>Right off the bat, we were thinking: <strong>Indonesian Language Model</strong>. Why? Because it is the kind of model we both had experience training, and it would be fun to learn JAX in the process (since we usually work in PyTorch).</p>

<p>At the same time, we ran into a dilemma. If you know a thing or two about Indonesian NLP, there are two major players in masked language modeling (MLM): <a href="https://huggingface.co/indobenchmark/indobert-base-p1">IndoNLU’s IndoBERT</a> and <a href="https://huggingface.co/indolem/indobert-base-uncased">IndoLEM’s IndoBERT</a>.</p>

<p>We thought, okay, how can we make something different? Or perhaps something better? Browsing through Indonesian datasets, the first thing that came to mind was, of course, the <a href="https://oscar-corpus.com/">OSCAR dataset</a> (16GB). But, we thought, if we wanted the model to perform better than the existing models, we should be using a larger dataset, shouldn’t we?</p>

<center>
<img src="/images/2021/07/hf-jax-week/muppets.png" style="zoom: 70%;" />
<figcaption><i>Muppets, some of whose names became catchy deep learning jargon.</i></figcaption>
</center>

<p>Despite the dilemma, we ended up posting the <a href="https://discuss.huggingface.co/t/pretrain-roberta-from-scratch-in-indonesian/7240">project proposal</a> on the HuggingFace forums anyway. Luckily, a day later, we got a reply from a user who suggested two alternative datasets: <a href="http://data.statmt.org/cc-100/">CC100</a> (36GB) and <a href="https://github.com/allenai/allennlp/discussions/5265">mC4</a> (230GB). So we thought, cool, let’s train <em>the</em> best Indonesian model!</p>

<h2 id="-setting-up">⚙ Setting up</h2>

<p>To kick things off, we began by setting up the TPU Virtual Machine <a href="https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm">as instructed by the HuggingFace team</a>. We found no major issues and the installation went pretty smoothly. All tokenizing and training scripts were ready, so no major code modification was needed to get the project started.</p>

<p>Fast forward, we began by training a tokenizer. <em>“So, which dataset will we use?”</em>, I asked. OSCAR? CC100? mC4? Heck, if we want to train <em>the</em> best model, why not use the largest dataset? And so I trained a <code class="language-plaintext highlighter-rouge">ByteLevelBPETokenizer</code> on the Indonesian mC4 subset, which took a good hour or two. Once the tokenizer finished training, we were ready to train the model! Or were we…?</p>

<p>Being naive, I went ahead and ran the training script, and boy, was I wrong. Pre-processing the gigantic mC4 dataset alone took ages, and I was impatient. <em>“Since we’re only creating a “trial” <code class="language-plaintext highlighter-rouge">RoBERTa</code> Base model, why bother training it on a huge dataset?”</em> I thought.</p>

<p>And so we took a step back and trained on OSCAR instead. Being 93% smaller in size, it took only a couple of minutes to train a tokenizer on it. Likewise, pre-processing finished fairly quickly, and the model soon began training.</p>

<h2 id="️-roaming-to-thai-nlp">🚶‍♂️ Roaming to Thai NLP</h2>

<p>It was at this point that I left my computer to get a haircut (true story). While the model kept training and my hair was being cut, I paused and thought: <em>“why not participate in another project?”</em>, since others were participating in more than one project at the same time.</p>

<p>Returning to my computer after the haircut (and shower) ended, I browsed through existing project proposals and found a less crowded one: <strong>Thai RoBERTa project</strong>. Cool! Why not join another project that similarly works on a low-resource language? Perhaps I can learn a thing or two from it…</p>

<p>And so I contacted the participant who’s responsible for the project: <a href="https://github.com/sakares">Sakares Saengkew</a>. We talked, exchanged ideas, and ultimately agreed to work on this project together. What I didn’t expect was becoming really good friends with someone whom I have never met in person, let alone someone based in Thailand 😆.</p>

<p>The more we talked, the closer our friendship grew. In the course of our conversations, we found out that we both enjoyed watching <strong>Dota 2</strong>, so that became the topic of our conversation for a good while 😂. Games aside, Sakares’ original plan kept going, though with some hurdles along the way.</p>

<h2 id="-debugging">🤔 Debugging</h2>

<p>As for my Indonesian model, it kept training and training, with an estimated training time of about 18 hours. At first, we were so happy that training <em>actually</em> took off. Mind you, we were Linux and machine learning amateurs, so getting things off the ground was already satisfying!</p>

<p>Both the training and evaluation losses seemed to be decreasing well, and the accuracy reached decent values within the first few epochs. <em>“All is well”</em>, we thought.</p>

<p>With some more hours to kill, I decided to train another language model, still using HuggingFace’s JAX framework, but this time on a personal Google Colab notebook. It was trained on Sundanese, a very low-resource language, and both the training and evaluation losses decreased just fine. It also achieved a decent accuracy, but then I noticed something odd…</p>

<p>Despite the accuracy reports, the language model was spitting out gibberish that didn’t reflect the metrics at all. <em>“Maybe something is wrong?”</em>, I thought. Indeed, I had trouble converting the JAX model to PyTorch due to its use of FP16. <em>“Aha!”</em>, I thought, that’s where the problem lies.</p>

<p>And so I opened a GitHub issue on the matter, to which HuggingFace’s <a href="https://github.com/patrickvonplaten">Patrick von Platen</a> responded quickly and professionally. Apparently, a “reverse-trick” that I had attempted in order to convert FP16 JAX models to FP32 was indeed the fix my model needed. What about the model’s outputs, though? Sadly, they remained gibberish.</p>
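<p>The exact “reverse-trick” isn’t reproduced here, but the underlying idea of upcasting half-precision weights can be sketched without any deep learning library. The snippet below is a hypothetical helper, <code class="language-plaintext highlighter-rouge">upcast_params</code>, that walks a nested dictionary of NumPy arrays (similar in shape to a Flax parameter tree) and casts every <code class="language-plaintext highlighter-rouge">float16</code> leaf to <code class="language-plaintext highlighter-rouge">float32</code>, leaving other dtypes untouched.</p>

```python
import numpy as np


def upcast_params(params):
    """Recursively cast float16 leaves of a nested dict of arrays to float32."""
    if isinstance(params, dict):
        return {name: upcast_params(value) for name, value in params.items()}
    array = np.asarray(params)
    return array.astype(np.float32) if array.dtype == np.float16 else array


# Hypothetical parameter tree with one FP16 kernel and one integer buffer.
params = {
    "encoder": {"kernel": np.ones((2, 2), dtype=np.float16)},
    "position_ids": np.arange(4),
}
upcast = upcast_params(params)
print(upcast["encoder"]["kernel"].dtype)  # float32
```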

<p>At this point, I thought I did something wrong along the training pipeline. <em>“Whatever”</em>, I said to myself, let’s just focus on the main dish: the Indonesian model.</p>

<h2 id="-not-one-but-two">✌ Not One, But Two</h2>

<p>Seeing my Indonesian model training just fine, I wanted to test its intermediate results after training it for about 6-8 hours. I pulled the model weights from the HuggingFace Hub, and voilà, <strong>gibberish output</strong>! The beast I expected to have trained was no different from its Sundanese counterpart 😓.</p>

<center>
<img src="/images/2021/07/hf-jax-week/bert-stare.gif" style="zoom: 100%;" />
<figcaption><i>Bert's stare, just like mine.</i></figcaption>
</center>

<p>Now I was left with two problems instead of one: badly trained models, but why? Naturally, I investigated what these two models had in common: JAX and the OSCAR dataset. The former seemed innocent, though, since nobody had reported a problem with it, and I was sure the HuggingFace team had checked the framework thoroughly…</p>

<p><em>“It must be the dataset!”</em>, I thought. But wait, while I dug through issues in HuggingFace’s GitHub repo, I found someone facing a similar problem: <a href="https://github.com/BirgerMoell">Birger Moëll</a>. Like ours, Birger’s models were spitting out gibberish despite decent training metrics. After eliminating other possible causes, we suspected that the dataset was the culprit of it all. Or was it?</p>

<p>We had a short interaction within the <a href="https://github.com/huggingface/transformers/issues/12554">GitHub issue</a> that Birger raised, but it turned into an even longer conversation in the official Slack channel. We exchanged dataset cleaning ideas and discussed our plans for this event. What we didn’t realize was that we were becoming good friends through this exchange.</p>

<h2 id="-failure-and-fortune">💀/💸 Failure and Fortune</h2>

<p>The event lasted for two weeks, and I learned countless lessons along the way from the people of HuggingFace, the other participants I met, and of course, from training the models themselves. In the end, all of us had a happy ending with our model training, but it wasn’t as smooth as many of us (or at least I) expected.</p>

<p>For instance, the <a href="https://huggingface.co/flax-community/indonesian-roberta-base">Indonesian RoBERTa Base model</a> turned out to be just fine. I pushed the final version after the entire training finished, converted the model to PyTorch, and somehow it wasn’t outputting gibberish?! All along, I may have been testing weights from the first epoch, or maybe even epoch zero judging from the performance 🤦‍♂️.</p>

<p>I was so close to giving up, but seeing the Base model working just as intended, I was back on track and became motivated to work on this project once again. The next boss to conquer: <code class="language-plaintext highlighter-rouge">RoBERTa</code> Large.</p>

<p>I naively thought training the Large rendition would be as trivial as training the Base model. But it turned out to be even more frustrating than the first attempt… Why? Well, unlike the Base model, <code class="language-plaintext highlighter-rouge">RoBERTa</code> Large didn’t take well to the same learning rate. The training loss fluctuated constantly, leading me to think that it was overshooting due to a learning rate that was too high (I was using <code class="language-plaintext highlighter-rouge">2e-4</code>).</p>

<p>And thus I decided to decrease it by about an order of magnitude (<code class="language-plaintext highlighter-rouge">2e-5</code>). It was, unsurprisingly, too low a learning rate... Even from the first few epochs, I could see that the model was not learning. I killed the process and increased the learning rate to <code class="language-plaintext highlighter-rouge">7e-5</code>. At that point, it was about midnight, so I crossed my fingers and went to sleep. I woke up excited the next day, only to find that it still wasn’t learning 😤. Not a lucky number 7 after all…</p>

<center>
<img src="/images/2021/07/hf-jax-week/roberta-large-training-loss.png" style="zoom: 70%;" />
<figcaption><i>Training loss of RoBERTa Large.</i></figcaption>
</center>

<p>Seeing how my time was running out (the model took ~2.5 days to train), I increased the learning rate just slightly this time (<code class="language-plaintext highlighter-rouge">8e-5</code>). I crossed my fingers yet again and left the model to train…</p>

<p>As it resumed training, I was delighted to hear that my friends were finding the light at the end of their tunnels as well. Birger’s models began to return more sensible outputs, and we found out that Sakares’ slightly incorrect tokenization scheme was the culprit of the slow model training.</p>

<p>As for myself, I was honestly disappointed to see the Large model still suffering from the same issue of “not learning”, as the evaluation loss looked somewhat flat at first. But when I talked to my teammate Steven, he suggested that we leave it as is this time and see how it fared, since we were really out of time at this point.</p>

<p>To my surprise, it finally learned! After about three or four epochs (~20 hours), the evaluation loss began to decrease! At least I could finally sleep without anxiety about model training. We quickly realized that, with the number of epochs we had set, the model could not possibly reach as high an accuracy as we wanted. Either way, it served as a lesson in learning-rate tuning and taught us that a scheduler’s warmup steps are just as important.</p>
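<p>For reference, here is a minimal plain-Python sketch of a linear warmup followed by linear decay, a schedule commonly used for RoBERTa-style pretraining. The peak learning rate of <code class="language-plaintext highlighter-rouge">8e-5</code> comes from the run above, while the warmup and total step counts are hypothetical. In JAX, a similar schedule can be built by joining two <code class="language-plaintext highlighter-rouge">optax.linear_schedule</code>s with <code class="language-plaintext highlighter-rouge">optax.join_schedules</code>.</p>

```python
def warmup_linear_decay(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear ramp from 0 to peak_lr, then linear decay to 0.

    Assumes 0 <= warmup_steps and warmup_steps < total_steps.
    """
    if step < warmup_steps:
        # Warmup phase: grow linearly from 0 towards peak_lr.
        return peak_lr * step / warmup_steps
    # Decay phase: shrink linearly to 0 at total_steps, and stay at 0 afterwards.
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)


# Hypothetical run: 1,000 warmup steps out of 10,000 total, peaking at 8e-5.
for step in (0, 500, 1_000, 5_500, 10_000):
    print(step, warmup_linear_decay(step, 8e-5, 1_000, 10_000))
```

<p>One common explanation for why warmup helps is that it keeps early updates small while the optimizer’s running statistics are still noisy; that is a plausible, not a confirmed, reason why a learning rate that suits Base can destabilize Large.</p>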

<h2 id="-extension--hope">😮 Extension == Hope</h2>

<p>Sometime later, the HuggingFace team announced that they would be extending the TPU access (and hence the event) by several days beyond the initial deadline. For me, this meant more time to explore and make the most of the tools at hand.</p>

<p>I decided to hop on the Thai NLP <em>train</em> and <em>train</em> a very trivial <a href="https://huggingface.co/flax-community/gpt2-base-thai">Thai GPT-2</a> on the OSCAR dataset. Since I had only about a day at most, I could only train it for very few epochs, and I left the model training overnight. To my surprise, it actually trained well?! The evaluation loss decreased as expected, and the predictions were relatively decent given the short window of time!</p>

<p>I immediately told Sakares to play around with the model, as I barely understand Thai. And indeed, the model’s predictions reflected the reported training metrics.</p>

<p>Unfortunately, our original plan of training a Thai RoBERTa couldn’t be completed before the deadline. Regardless, Sakares said that it was okay, since we could still train it using Colab Pro if we wanted to.</p>

<center>
<img src="/images/2021/07/hf-jax-week/roberta-indonesian-demo.png" style="zoom: 70%;" />
<figcaption><i>A preview of our Indonesian RoBERTa model demo.</i></figcaption>
</center>

<p>As the HuggingFace team announced their beta feature Spaces, my team and I began ideating a demo of our trained models. We fine-tuned the Indonesian RoBERTa Base model on existing downstream tasks from IndoNLU, including <a href="https://huggingface.co/StevenLimcorn/indonesian-roberta-base-emotion-classifier">emotion classification</a>, <a href="https://huggingface.co/w11wo/indonesian-roberta-base-sentiment-classifier">sentiment analysis</a>, and <a href="https://huggingface.co/w11wo/indonesian-roberta-base-posp-tagger">part-of-speech (POS) tagging</a>. We used the first two in our <a href="https://huggingface.co/spaces/flax-community/roberta-indonesian">model demo</a>, as well as the pre-trained masked language model itself.</p>

<p>As for my Thai NLP project with Sakares, we ended up putting together a last-minute <a href="https://huggingface.co/spaces/flax-community/gpt2-thai/">Thai GPT-2 model demo</a> 😂. Birger similarly deployed various models into one awesome demo titled <a href="https://huggingface.co/spaces/birgermoell/language-explorer">Language Explorer</a>. In the end, we all found the light at the end of our tunnels.</p>

<h2 id="-closing-thoughts">🚀 Closing Thoughts</h2>

<p>Although none of us managed to place among the top 15 projects, the virtual event was nonetheless a memorable one. I learned a ton from the people I met, and ultimately had fun participating in my first-ever online community event hosted by HuggingFace and Google Cloud. I cannot thank my friends and organizers enough for making this experience possible. And I cannot wait to join the next HuggingFace community event 🤗.</p>

<p>To Steven, Sakares, Birger, and the friendliest team behind HuggingFace &amp; JAX, thank you.</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Transformer" /><category term="Community" /><summary type="html"><![CDATA[On June 23, the HuggingFace team announced that they are planning to host a community week together with the people from the Google Cloud team. The main gist of this event was getting everyone to learn and use HuggingFace’s newly integrated JAX framework. But aside from just learning from tutorials, we were equipped with blazing fast TPUs thanks to the amazing Google Cloud team 🤯.]]></summary></entry><entry><title type="html">Pneumonia Chest X-Ray Classification</title><link href="https://wilsonwongso.dev/posts/2020/08/pneumonia-chest-xray-classification/" rel="alternate" type="text/html" title="Pneumonia Chest X-Ray Classification" /><published>2020-08-31T00:00:00+10:00</published><updated>2020-08-31T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/08/pneumonia-chest-xray-classification</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/08/pneumonia-chest-xray-classification/"><![CDATA[<p>The dataset used for this task is from a <a href="https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia">Kaggle dataset</a> by Paul Mooney. It consists of two kinds of chest x-rays: those showing pneumonia infection and normal ones. Our main goal is to distinguish the pneumonia-infected chest x-rays from the normal ones. Note that the dataset is highly imbalanced, as many medical image datasets are.</p>

<h3 id="fastai2-library">Fast.ai2 Library</h3>

<p>Fast.ai has just released version 2 of its framework. It bundles many existing features along with shiny new ones that weren’t previously available, such as its new medical imaging applications. Although this task doesn’t actually use the medical applications, it serves as a stepping-stone towards them.</p>

<h3 id="fastbook">Fastbook</h3>

<p>Aside from releasing version 2 of its framework, fast.ai also released a companion book dubbed fastbook. The book is available for free in the form of <a href="https://github.com/fastai/fastbook">Jupyter notebooks</a>, but one can also purchase a print version on <a href="https://www.amazon.com/Deep-Learning-Coders-fastai-PyTorch/dp/1492045527">Amazon</a>. More importantly, this task applies what I’ve learned from Chapter 7 of the book, <strong>Training a State-of-the-Art Model</strong>.</p>

<h3 id="transfer-learning">Transfer Learning</h3>

<p>Lastly, I’ve also applied transfer learning in this task, since it performed better over a couple of trial runs. The particular model I’ll be using is EfficientNetB3A, with weights from Ross Wightman’s <strong>timm</strong> <a href="https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py">library</a>.</p>

<h2 id="code">Code</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">fastai</span>
<span class="kn">from</span> <span class="nn">fastai.vision.all</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span> <span class="nn">fastai.vision.core</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span> <span class="nn">fastai.callback</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span> <span class="nn">fastai.metrics</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">precision_score</span><span class="p">,</span> <span class="n">recall_score</span><span class="p">,</span> <span class="n">confusion_matrix</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">timm</span> <span class="kn">import</span> <span class="n">create_model</span>
</code></pre></div></div>

<h3 id="load-data">Load Data</h3>

<p>The dataset is very imbalanced. Firstly, it has more pneumonia-infected chest x-rays than normal ones. I tried oversampling using PyTorch’s <code class="language-plaintext highlighter-rouge">WeightedRandomSampler</code>, but it didn’t show much of an improvement. Secondly, it has a very small validation set: only 16 images in total. As such, measuring the model by its validation accuracy alone seems unwise.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="s">"chest_xray/chest_xray"</span><span class="p">)</span>
</code></pre></div></div>
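<p>For readers who want to try oversampling anyway, the sketch below computes per-image weights by inverse class frequency, so that each class is drawn with equal total probability. The resulting array is the kind of input PyTorch’s <code class="language-plaintext highlighter-rouge">WeightedRandomSampler(weights, num_samples, replacement=True)</code> expects. The label names are assumptions based on the dataset’s <code class="language-plaintext highlighter-rouge">NORMAL</code>/<code class="language-plaintext highlighter-rouge">PNEUMONIA</code> folders, and the helper is a hypothetical sketch, not necessarily the setup used in this experiment.</p>

```python
import numpy as np


def inverse_frequency_weights(labels):
    """Per-sample weights of 1 / (number of samples in that sample's class)."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    class_weight = dict(zip(classes, 1.0 / counts))
    return np.array([class_weight[label] for label in labels])


# Toy example: three pneumonia images and one normal image.
labels = ["PNEUMONIA", "PNEUMONIA", "PNEUMONIA", "NORMAL"]
weights = inverse_frequency_weights(labels)
print(weights)
```

<p>Because each class’s weights sum to 1, a sampler drawing with replacement picks normal and pneumonia images equally often on average, at the cost of repeating the minority-class images.</p>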

<h4 id="data-augmentation">Data Augmentation</h4>

<p>First up in the loading process is data augmentation, along with normalization. I normalize the images with ImageNet statistics, since the pretrained model was trained on images normalized with the same statistics. Moreover, I apply fast.ai’s default augmentation transforms (<code class="language-plaintext highlighter-rouge">aug_transforms</code>), coupled with a random resized crop on each item.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_tfms</span> <span class="o">=</span> <span class="p">[</span><span class="n">Normalize</span><span class="p">.</span><span class="n">from_stats</span><span class="p">(</span><span class="o">*</span><span class="n">imagenet_stats</span><span class="p">),</span> <span class="o">*</span><span class="n">aug_transforms</span><span class="p">()]</span>
</code></pre></div></div>
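
<p>Normalization itself is just a per-channel standardization, x' = (x - mean_c) / std_c. A NumPy sketch, assuming float images scaled to [0, 1] in (N, C, H, W) layout and the standard ImageNet mean/std values, which I believe are what fastai’s <code class="language-plaintext highlighter-rouge">imagenet_stats</code> holds:</p>

```python
import numpy as np

# ImageNet per-channel (RGB) mean and std, shaped to broadcast over (N, C, H, W).
mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def normalize(batch):
    """Standardise a float batch of shape (N, 3, H, W) with values in [0, 1]."""
    return (batch - mean) / std

def denormalize(batch):
    """Invert normalize(), e.g. to display the images again."""
    return batch * std + mean

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 4))  # two tiny fake RGB images
x_norm = normalize(x)
print(x_norm.mean(axis=(0, 2, 3)))  # per-channel means after normalization
```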

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_dls</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">size</span><span class="p">):</span>
    <span class="n">dblock</span> <span class="o">=</span> <span class="n">DataBlock</span><span class="p">(</span><span class="n">blocks</span>     <span class="o">=</span> <span class="p">(</span><span class="n">ImageBlock</span><span class="p">,</span> <span class="n">CategoryBlock</span><span class="p">),</span>
                       <span class="n">get_items</span>  <span class="o">=</span> <span class="n">get_image_files</span><span class="p">,</span>
                       <span class="n">get_y</span>      <span class="o">=</span> <span class="n">parent_label</span><span class="p">,</span>
                       <span class="n">splitter</span>   <span class="o">=</span> <span class="n">GrandparentSplitter</span><span class="p">(</span><span class="n">valid_name</span><span class="o">=</span><span class="s">'val'</span><span class="p">),</span>
                       <span class="n">item_tfms</span>  <span class="o">=</span> <span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">min_scale</span><span class="o">=</span><span class="mf">0.75</span><span class="p">),</span>
                       <span class="n">batch_tfms</span> <span class="o">=</span> <span class="n">batch_tfms</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">dblock</span><span class="p">.</span><span class="n">dataloaders</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="n">bs</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">cuda</span><span class="p">()</span>
</code></pre></div></div>
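
<p>The random resized crop picks a region covering a random fraction of the image area (at least 75% here, via <code class="language-plaintext highlighter-rouge">min_scale=0.75</code>) with a random aspect ratio, and then resizes it to the target size. The function below is a simplified, hypothetical sketch of that sampling step in plain Python. The 3/4 to 4/3 aspect-ratio range is an assumption borrowed from the common torchvision/fastai default, and fastai’s real fallback differs from the centred square used here.</p>

```python
import math
import random

def sample_crop_box(width, height, min_scale=0.75, ratio=(3 / 4, 4 / 3), tries=10, rng=random):
    """Return (left, top, crop_w, crop_h) for a random crop covering
    min_scale..1 of the image area with a random aspect ratio.

    Simplified sketch: the crop would then be resized to the target size."""
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(tries):
        target_area = area * rng.uniform(min_scale, 1.0)
        # Log-uniform aspect ratio so wide and tall crops are equally likely.
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    # Fallback (simplified): the largest centred square.
    side = min(width, height)
    return (width - side) // 2, (height - side) // 2, side, side

print(sample_crop_box(1024, 768, rng=random.Random(0)))
```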

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span> <span class="o">=</span> <span class="n">get_dls</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">224</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span><span class="p">.</span><span class="n">show_batch</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/08/pneumonia-chest-xray-classification/output_10_0.png" style="zoom: 70%;" />
</center>

<h3 id="model">Model</h3>

<p>As mentioned, we’ll be using a pretrained model called <strong>EfficientNetB3A</strong>. The next few code blocks are taken from Zachary Mueller’s <strong>Practical-Deep-Learning-for-Coders-2.0</strong> notebook tutorials. In particular, his <a href="https://github.com/muellerzr/Practical-Deep-Learning-for-Coders-2.0/blob/master/Computer%20Vision/05_EfficientNet_and_Custom_Weights.ipynb">notebook</a> titled <strong>05 EfficientNet and Custom Pretrained Models</strong> shows how to create a timm body, load pretrained weights, create a matching model head, and combine the two.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">create_timm_body</span><span class="p">(</span><span class="n">arch</span><span class="p">:</span><span class="nb">str</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">cut</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">create_model</span><span class="p">(</span><span class="n">arch</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="n">pretrained</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">cut</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">ll</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">children</span><span class="p">()))</span>
        <span class="n">cut</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">o</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="n">ll</span><span class="p">)</span> <span class="k">if</span> <span class="n">has_pool_type</span><span class="p">(</span><span class="n">o</span><span class="p">))</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cut</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">children</span><span class="p">())[:</span><span class="n">cut</span><span class="p">])</span>
    <span class="k">elif</span> <span class="nb">callable</span><span class="p">(</span><span class="n">cut</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">cut</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"cut must be either integer or function"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">body</span> <span class="o">=</span> <span class="n">create_timm_body</span><span class="p">(</span><span class="s">'efficientnet_b3a'</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">nf</span> <span class="o">=</span> <span class="n">num_features_model</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">body</span><span class="p">.</span><span class="n">children</span><span class="p">()))</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span><span class="p">)</span>
<span class="n">head</span> <span class="o">=</span> <span class="n">create_head</span><span class="p">(</span><span class="n">nf</span><span class="p">,</span> <span class="n">dls</span><span class="p">.</span><span class="n">c</span><span class="p">)</span>
</code></pre></div></div>
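
<p><code class="language-plaintext highlighter-rouge">num_features_model</code> reports how many channels the body outputs. The <code class="language-plaintext highlighter-rouge">* (2)</code> is there because fastai’s default head starts with concat pooling, which concatenates a global max-pooled and a global average-pooled copy of the features and therefore doubles their count. Below is a NumPy sketch of that pooling step; the 1536-channel, 7x7 body output is an assumed shape for illustration. Be aware that, depending on the fastai version, <code class="language-plaintext highlighter-rouge">create_head</code> may already double <code class="language-plaintext highlighter-rouge">nf</code> itself when <code class="language-plaintext highlighter-rouge">concat_pool=True</code>, in which case the manual <code class="language-plaintext highlighter-rouge">* (2)</code> would double-count; check your installed version.</p>

```python
import numpy as np

def concat_pool(features):
    """Concatenate global max- and average-pooled features.

    features: array of shape (N, C, H, W) -> returns shape (N, 2 * C)."""
    max_pooled = features.max(axis=(2, 3))
    avg_pooled = features.mean(axis=(2, 3))
    return np.concatenate([max_pooled, avg_pooled], axis=1)

# Assumed body output: a batch of 2, 1536 channels on a 7x7 grid.
body_out = np.random.default_rng(0).standard_normal((2, 1536, 7, 7))
pooled = concat_pool(body_out)
print(pooled.shape)  # (2, 3072)
```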

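<p>After building the model, we apply Kaiming normal initialization to its second part, the freshly created head (<code class="language-plaintext highlighter-rouge">model[1]</code>), since the pretrained body already has trained weights. Kaiming He’s initialization technique is introduced in this <a href="https://arxiv.org/abs/1502.01852">paper</a>.</p>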

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">body</span><span class="p">,</span> <span class="n">head</span><span class="p">)</span>
<span class="n">apply_init</span><span class="p">(</span><span class="n">model</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">nn</span><span class="p">.</span><span class="n">init</span><span class="p">.</span><span class="n">kaiming_normal_</span><span class="p">)</span>
</code></pre></div></div>
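
<p>Kaiming normal initialization draws each weight from a zero-mean normal distribution with standard deviation gain / sqrt(fan_in), where gain = sqrt(2) for ReLU layers, so that activations keep roughly the same variance from layer to layer. A NumPy sketch for a single linear layer; the 3072-to-512 shape is an assumed example, and the defaults mirror PyTorch’s <code class="language-plaintext highlighter-rouge">nn.init.kaiming_normal_</code> (<code class="language-plaintext highlighter-rouge">mode='fan_in'</code>, which gives gain sqrt(2) when <code class="language-plaintext highlighter-rouge">a=0</code>):</p>

```python
import numpy as np

def kaiming_normal(fan_out, fan_in, gain=np.sqrt(2.0), rng=None):
    """Draw a (fan_out, fan_in) weight matrix from N(0, (gain / sqrt(fan_in))**2).

    gain = sqrt(2) corresponds to ReLU, matching PyTorch's fan_in default."""
    rng = np.random.default_rng() if rng is None else rng
    std = gain / np.sqrt(fan_in)
    return rng.normal(0.0, std, size=(fan_out, fan_in))

w = kaiming_normal(512, 3072, rng=np.random.default_rng(0))
print(w.std(), np.sqrt(2.0 / 3072))  # the two values should be close
```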

<p>We’ll use <code class="language-plaintext highlighter-rouge">LabelSmoothingCrossEntropy</code> as the loss function and the <code class="language-plaintext highlighter-rouge">MixUp</code> callback, as suggested in fastbook. Both are regularization techniques that may help improve the model’s accuracy. You can find the papers introducing Label Smoothing <a href="https://arxiv.org/abs/1512.00567">here</a> and Mixup <a href="https://arxiv.org/abs/1710.09412">here</a>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span> <span class="o">=</span> <span class="n">Learner</span><span class="p">(</span><span class="n">dls</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_func</span><span class="o">=</span><span class="n">LabelSmoothingCrossEntropy</span><span class="p">(),</span> <span class="n">metrics</span><span class="o">=</span><span class="n">accuracy</span><span class="p">,</span> <span class="n">cbs</span><span class="o">=</span><span class="n">MixUp</span><span class="p">())</span>
</code></pre></div></div>
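
<p>Both ideas are short to write down. Label smoothing replaces the one-hot target with (1 - eps) * one_hot + eps / C, and Mixup blends pairs of images (and their one-hot labels) with a weight lam drawn from Beta(alpha, alpha). The NumPy sketch below is a hypothetical re-implementation of those formulas, not fastai’s code; eps=0.1 and alpha=0.4 are assumed values that I believe match fastai’s defaults. Mixing one-hot labels is equivalent to mixing the two cross-entropy losses, because cross-entropy is linear in the target distribution.</p>

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def label_smoothing_ce(logits, targets, eps=0.1):
    """Cross-entropy against the smoothed target (1 - eps) * one_hot + eps / C."""
    log_probs = log_softmax(logits)
    n, c = log_probs.shape
    nll = -log_probs[np.arange(n), targets].mean()  # ordinary cross-entropy
    uniform_ce = -log_probs.sum(axis=1).mean() / c  # cross-entropy vs. a uniform target
    return (1 - eps) * nll + eps * uniform_ce

def mixup(x, y_onehot, alpha=0.4, rng=None):
    """Blend every sample (and its one-hot label) with a random partner from the batch."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(x))
    x_mix = lam * x + (1 - lam) * x[perm]
    y_mix = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mix, y_mix, lam

# Tiny demo with 4 fake 2-class logits.
demo_logits = np.array([[2.0, -1.0], [0.5, 0.5], [-1.0, 3.0], [0.0, 1.0]])
demo_targets = np.array([0, 1, 1, 0])
print(label_smoothing_ce(demo_logits, demo_targets))
```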

<p>Since training takes up a lot of GPU memory, one GPU wasn’t enough. Luckily, I have two NVIDIA GeForce GTX 980M GPUs, so I split the work across both of them using PyTorch’s <code class="language-plaintext highlighter-rouge">DataParallel</code>, which replicates the model on every GPU and gives each replica part of every batch.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">device_count</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Let's use"</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">device_count</span><span class="p">(),</span> <span class="s">"GPUs!"</span><span class="p">)</span>
    <span class="n">learn</span><span class="p">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">learn</span><span class="p">.</span><span class="n">model</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Let's use 2 GPUs!
</code></pre></div></div>
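
<p>Conceptually, <code class="language-plaintext highlighter-rouge">nn.DataParallel</code> splits each batch along its first dimension, runs a model replica on each GPU, and concatenates the outputs on the first device. The hypothetical, CPU-only NumPy function below emulates that split/concatenate flow for a model that processes each sample independently. One caveat it hides: with the real <code class="language-plaintext highlighter-rouge">DataParallel</code>, BatchNorm layers compute their statistics per replica, i.e. over 32 images each here instead of 64.</p>

```python
import numpy as np

def data_parallel_forward(model_fn, batch, n_devices=2):
    """Emulate nn.DataParallel's flow on the CPU: scatter the batch along dim 0,
    apply the same model to each chunk, then gather (concatenate) the outputs."""
    chunks = np.array_split(batch, n_devices, axis=0)  # one chunk per "GPU"
    outputs = [model_fn(chunk) for chunk in chunks]     # runs in parallel on real GPUs
    return np.concatenate(outputs, axis=0)

rng = np.random.default_rng(0)
weights = rng.standard_normal((10, 2))
batch = rng.standard_normal((64, 10))
out = data_parallel_forward(lambda chunk: chunk @ weights, batch)
print(out.shape)  # (64, 2)
```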

<h3 id="training-model">Training Model</h3>

<p>Once everything has been set up, we can find a good learning rate to train the model with.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">lr_find</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>c:\users\wilso\appdata\local\programs\python\python38\lib\site-packages\torch\cuda\nccl.py:14: UserWarning: PyTorch is not compiled with NCCL support
  warnings.warn('PyTorch is not compiled with NCCL support')
SuggestedLRs(lr_min=0.006918309628963471, lr_steep=9.120108734350652e-05)
</code></pre></div></div>

<center>
<img src="/images/2020/08/pneumonia-chest-xray-classification/output_22_3.png" style="zoom: 70%;" />
</center>

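<p>Here we’ll train the model for 10 epochs using the <a href="https://arxiv.org/abs/1708.07120">one-cycle policy</a>, with a maximum learning rate of <code class="language-plaintext highlighter-rouge">6e-3</code> (close to the suggested <code class="language-plaintext highlighter-rouge">lr_min</code>) and a weight decay of <code class="language-plaintext highlighter-rouge">0.1</code>. <code class="language-plaintext highlighter-rouge">SaveModelCallback</code> saves the weights whenever the validation loss improves.</p>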

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mf">6e-3</span><span class="p">,</span> <span class="n">wd</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">cbs</span><span class="o">=</span><span class="n">SaveModelCallback</span><span class="p">(</span><span class="n">fname</span><span class="o">=</span><span class="s">'best-val-loss'</span><span class="p">))</span>
<span class="n">learn</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">'efficientnetb3a-1'</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>accuracy</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>0.762553</td>
      <td>1.285270</td>
      <td>0.500000</td>
      <td>02:28</td>
    </tr>
    <tr>
      <td>1</td>
      <td>0.560403</td>
      <td>1.788918</td>
      <td>0.500000</td>
      <td>02:27</td>
    </tr>
    <tr>
      <td>2</td>
      <td>0.497399</td>
      <td>0.727971</td>
      <td>0.562500</td>
      <td>02:26</td>
    </tr>
    <tr>
      <td>3</td>
      <td>0.460750</td>
      <td>0.842557</td>
      <td>0.625000</td>
      <td>02:25</td>
    </tr>
    <tr>
      <td>4</td>
      <td>0.532170</td>
      <td>8.171339</td>
      <td>0.625000</td>
      <td>02:25</td>
    </tr>
    <tr>
      <td>5</td>
      <td>0.493482</td>
      <td>2.005133</td>
      <td>0.687500</td>
      <td>02:30</td>
    </tr>
    <tr>
      <td>6</td>
      <td>0.435214</td>
      <td>0.956962</td>
      <td>0.562500</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>7</td>
      <td>0.397469</td>
      <td>0.727003</td>
      <td>0.562500</td>
      <td>03:12</td>
    </tr>
    <tr>
      <td>8</td>
      <td>0.377309</td>
      <td>0.713967</td>
      <td>0.625000</td>
      <td>02:27</td>
    </tr>
    <tr>
      <td>9</td>
      <td>0.380401</td>
      <td>0.636497</td>
      <td>0.625000</td>
      <td>02:25</td>
    </tr>
  </tbody>
</table>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Better model found at epoch 0 with valid_loss value: 1.2852699756622314.
Better model found at epoch 2 with valid_loss value: 0.7279710173606873.
Better model found at epoch 7 with valid_loss value: 0.7270027995109558.
Better model found at epoch 8 with valid_loss value: 0.7139670252799988.
Better model found at epoch 9 with valid_loss value: 0.6364966630935669.
Path('models/efficientnetb3a-1.pth')
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">recorder</span><span class="p">.</span><span class="n">plot_loss</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/08/pneumonia-chest-xray-classification/output_25_0.png" style="zoom: 70%;" />
</center>
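
<p>The one-cycle schedule first warms the learning rate up from a low value to the maximum, then anneals it down to a much smaller value. Below is a pure-Python sketch of the learning-rate part, assuming cosine annealing in both phases and the defaults <code class="language-plaintext highlighter-rouge">div=25</code>, <code class="language-plaintext highlighter-rouge">div_final=1e5</code>, <code class="language-plaintext highlighter-rouge">pct_start=0.25</code>, which is my understanding of fastai’s <code class="language-plaintext highlighter-rouge">fit_one_cycle</code> defaults. Momentum, which fastai cycles in the opposite direction, is left out.</p>

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return start + (1 + math.cos(math.pi * (1 - pct))) * (end - start) / 2

def one_cycle_lr(pct, lr_max=6e-3, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] (sketch, assumed defaults)."""
    if pct < pct_start:
        # Warm-up: lr_max / div -> lr_max over the first pct_start of training.
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing: lr_max -> lr_max / div_final over the rest of training.
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.1, 0.25, 0.5, 0.75, 1.0):
    print(f"{pct:.2f}: {one_cycle_lr(pct):.2e}")
```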

<h3 id="testing-model">Testing Model</h3>

<p>As mentioned, the validation set is too small to measure our model’s performance. Fortunately, the dataset comes with a much larger test set, which we’ll use instead.</p>

<h4 id="load-test-data">Load Test Data</h4>

<p>The method I use here, and for the rest of this notebook, is a workaround. Specifically, I create new dataloaders whose validation split is the test folder, and swap them into the learner in place of the old ones.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_test_dls</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">test_folder</span><span class="p">):</span>
    <span class="n">dblock</span> <span class="o">=</span> <span class="n">DataBlock</span><span class="p">(</span><span class="n">blocks</span>     <span class="o">=</span> <span class="p">(</span><span class="n">ImageBlock</span><span class="p">,</span> <span class="n">CategoryBlock</span><span class="p">),</span>
                       <span class="n">get_items</span>  <span class="o">=</span> <span class="n">get_image_files</span><span class="p">,</span>
                       <span class="n">get_y</span>      <span class="o">=</span> <span class="n">parent_label</span><span class="p">,</span>
                       <span class="n">splitter</span>   <span class="o">=</span> <span class="n">GrandparentSplitter</span><span class="p">(</span><span class="n">valid_name</span><span class="o">=</span><span class="n">test_folder</span><span class="p">),</span>
                       <span class="n">item_tfms</span>  <span class="o">=</span> <span class="n">Resize</span><span class="p">(</span><span class="n">size</span><span class="p">),</span>
                       <span class="n">batch_tfms</span> <span class="o">=</span> <span class="n">batch_tfms</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">dblock</span><span class="p">.</span><span class="n">dataloaders</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="n">bs</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">cuda</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test_dl</span> <span class="o">=</span> <span class="n">get_test_dls</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="s">'test'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">dls</span> <span class="o">=</span> <span class="n">test_dl</span>
</code></pre></div></div>

<h4 id="test-accuracy">Test Accuracy</h4>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">preds</span><span class="p">,</span> <span class="n">targs</span> <span class="o">=</span> <span class="n">learn</span><span class="p">.</span><span class="n">get_preds</span><span class="p">()</span>
<span class="n">accuracy</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">targs</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor(0.9231)
</code></pre></div></div>

<h3 id="analyze-results">Analyze Results</h3>

<p>The model achieved 92% accuracy on the test data. However, accuracy is a misleading measure of performance on an imbalanced dataset. Say we have 95 normal chest x-rays and 5 pneumonia-infected ones: naively predicting all 100 as normal would still score a high 95% accuracy, while missing every pneumonia case. Hence, <strong>Precision</strong> and <strong>Recall</strong> are better metrics to use in this case.</p>

<p>According to the scikit-learn <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score">docs</a>, precision is <em>intuitively the ability of the classifier not to label as positive a sample that is negative</em>, while recall is <em>intuitively the ability of the classifier to find all the positive samples.</em></p>
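
<p>In formula form, precision = TP / (TP + FP) and recall = TP / (TP + FN), with pneumonia as the positive class. The pure-Python sketch below uses a hypothetical helper (not scikit-learn’s implementation) to score the toy classifier from the example above that labels every x-ray as normal: it reaches 95% accuracy but zero recall. Precision is reported as 0 when there are no positive predictions, which matches scikit-learn’s default fallback value.</p>

```python
def precision_recall(y_true, y_pred, positive=1):
    """Precision and recall for one positive class, from raw label lists."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Toy example: 95 normal (0) and 5 pneumonia (1) x-rays,
# and a "classifier" that labels everything as normal.
y_true = [0] * 95 + [1] * 5
y_pred = [0] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)                          # 0.95
print(precision_recall(y_true, y_pred))  # (0.0, 0.0)
```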

<p>We can visualize the test predictions with a confusion matrix. Plotting such a diagram is already built into the fast.ai library, but for some reason I couldn’t get it to work with this new release. Thus, I decided to copy the actual fast.ai <code class="language-plaintext highlighter-rouge">interpret</code> <a href="https://github.com/fastai/fastai/blob/master/fastai/interpret.py#L51">code</a> and modify it to fix the issue.</p>

<p>I found that fast.ai’s <code class="language-plaintext highlighter-rouge">confusion_matrix</code> code was what broke the plotting code that depends on it. To fix the issue, I replaced it with scikit-learn’s <code class="language-plaintext highlighter-rouge">confusion_matrix</code>. Lastly, I made the function also print the precision and recall scores, both computed with scikit-learn.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">plot_confusion_matrix</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">vocab</span><span class="p">):</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">cm</span> <span class="o">=</span> <span class="n">confusion_matrix</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>

    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">60</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">cm</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s">'nearest'</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"Blues"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"Confusion Matrix"</span><span class="p">)</span>
    <span class="n">tick_marks</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">vocab</span><span class="p">))</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">tick_marks</span><span class="p">,</span> <span class="n">vocab</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">90</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">yticks</span><span class="p">(</span><span class="n">tick_marks</span><span class="p">,</span> <span class="n">vocab</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">thresh</span> <span class="o">=</span> <span class="n">cm</span><span class="p">.</span><span class="nb">max</span><span class="p">()</span> <span class="o">/</span> <span class="mf">2.</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">itertools</span><span class="p">.</span><span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">cm</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">range</span><span class="p">(</span><span class="n">cm</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])):</span>
        <span class="n">coeff</span> <span class="o">=</span> <span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">cm</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span><span class="si">}</span><span class="s">'</span>
        <span class="n">plt</span><span class="p">.</span><span class="n">text</span><span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">coeff</span><span class="p">,</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s">"center"</span><span class="p">,</span> <span class="n">verticalalignment</span><span class="o">=</span><span class="s">"center"</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">"white"</span> <span class="k">if</span> <span class="n">cm</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">thresh</span> <span class="k">else</span> <span class="s">"black"</span><span class="p">)</span>

    <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">gca</span><span class="p">()</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span><span class="o">-</span><span class="p">.</span><span class="mi">5</span><span class="p">,</span><span class="o">-</span><span class="p">.</span><span class="mi">5</span><span class="p">)</span>

    <span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Actual'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Predicted'</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">grid</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Precision: </span><span class="si">{</span><span class="n">precision_score</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Recall: </span><span class="si">{</span><span class="n">recall_score</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_confusion_matrix</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">targs</span><span class="p">,</span> <span class="n">dls</span><span class="p">.</span><span class="n">vocab</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Precision: 0.896
Recall: 0.992
</code></pre></div></div>

<center>
<img src="/images/2020/08/pneumonia-chest-xray-classification/output_35_1.png" style="zoom: 70%;" />
</center>

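<p>The model achieved 89.6% precision and 99.2% recall, with <code class="language-plaintext highlighter-rouge">PNEUMONIA</code> as the positive class (fastai sorts the vocab alphabetically, so <code class="language-plaintext highlighter-rouge">PNEUMONIA</code> is label 1, which scikit-learn treats as positive by default). In other words, it misses very few pneumonia cases, at the cost of flagging some normal x-rays as pneumonia.</p>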

<h2 id="closing-remarks">Closing Remarks</h2>

<p>Despite all the issues and troubles I stumbled upon during this project, I learned to be more flexible in using the tools available. I had also attempted this task over a year ago as a beginner in deep learning. To my surprise, I had actually solved it back then using a pretrained VGG19 model in TensorFlow/Keras, and attained quite a satisfying result as well.</p>

<p>In any case, this mini project taught me a lot, and I’m excited to learn even more about deep learning. Hope you’ve learned something!</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Convolutional Neural Network" /><summary type="html"><![CDATA[The dataset used for this task is from a Kaggle dataset by Paul Mooney. It consists of two kinds of chest x-rays: those infected by pneumonia, and normal ones. Our main goal is to distinguish pneumonia-infected chest x-rays from normal ones. Note that the dataset is highly imbalanced, as many medical image datasets are.]]></summary></entry><entry><title type="html">Text Generation using minGPT and fast.ai</title><link href="https://wilsonwongso.dev/posts/2020/08/text-generation-with-mingpt-fastai/" rel="alternate" type="text/html" title="Text Generation using minGPT and fast.ai" /><published>2020-08-24T00:00:00+10:00</published><updated>2020-08-24T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/08/text-generation-with-mingpt-fastai</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/08/text-generation-with-mingpt-fastai/"><![CDATA[<p>Andrej Karpathy, Tesla’s AI Director, released minGPT, a minimal re-implementation of OpenAI’s GPT. Full-size GPT models can have billions of parameters and are expensive to train. Karpathy’s approach is to provide a small, simple version of GPT, hence the name minGPT.</p>

<h3 id="mingpt--fastai">minGPT + fast.ai</h3>

<p>Fast.ai has just released version 2.0, a complete rewrite of its predecessor. It works with various other PyTorch libraries and can also integrate with plain PyTorch code. Morgan McGuire (<a href="https://github.com/morganmcg1">morganmcg1</a> on GitHub) shared code that incorporates Karpathy’s minGPT into fast.ai. <strong>This project builds upon McGuire’s code.</strong> Credits to Morgan McGuire for the code. I do not own it; I simply changed minor bits of it (data and hyperparameters).</p>

<h3 id="yabes-elia--zilbest">Yabes Elia &amp; Zilbest</h3>

<p>Yabes Elia is an editor of esports articles, and he was, and still is, my editor. Before that, he blogged on his own site, Zilbest.com, which covered several topics, including philosophy, romance, and psychology. After reading his blog posts, I got the idea to train a language model on his writing: I thought it would be interesting to let a deep learning model learn a person’s writing style. <strong>Credit to <em>mas</em> Yabes Elia for allowing me to use his blog posts on Zilbest.com as a data source.</strong></p>

<h2 id="code">Code</h2>

<p>The following code is based on morganmcg1’s “<a href="https://gist.github.com/morganmcg1/b2a26e213482d3355a3d3a64c91e94ac">A Quick Demo of Andrej Karpathy’s minGPT Play Char Demo</a>”. Only the most important blocks of code are included here.</p>

<h3 id="loading-data">Loading Data</h3>

<p>The data is simply a .txt file containing Yabes Elia’s articles from Zilbest. I uploaded the .txt file to my Google Drive, loaded it, and displayed its first 100 characters.</p>

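<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># the articles contain curly quotes, so read the file explicitly as UTF-8</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">drive_path</span><span class="o">/</span><span class="s">'yabes-elia.txt'</span><span class="p">,</span> <span class="s">'r'</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s">'utf-8'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">raw_text</span> <span class="o">=</span> <span class="n">f</span><span class="p">.</span><span class="n">read</span><span class="p">()</span>
<span class="n">raw_text</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">100</span><span class="p">]</span>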
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'“You will never be happy if you continue to search for what happiness consists of. You will never li'
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">len</span><span class="p">(</span><span class="n">raw_text</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>227914
</code></pre></div></div>

<h3 id="transforms">Transforms</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">CharTransform</span><span class="p">(</span><span class="n">Transform</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">block_size</span><span class="p">):</span>
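        <span class="c1"># vocabulary: every distinct character (set order is not guaranteed to be stable across Python sessions)</span>
        <span class="n">chars</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">data</span><span class="p">))</span>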
        <span class="n">data_size</span><span class="p">,</span> <span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">chars</span><span class="p">)</span>
        <span class="k">print</span><span class="p">(</span><span class="s">'data has %d characters, %d unique.'</span> <span class="o">%</span> <span class="p">(</span><span class="n">data_size</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">))</span>

        <span class="c1"># lookup tables between characters and integer ids</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">stoi</span> <span class="o">=</span> <span class="p">{</span> <span class="n">ch</span><span class="p">:</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">ch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chars</span><span class="p">)</span> <span class="p">}</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">itos</span> <span class="o">=</span> <span class="p">{</span> <span class="n">i</span><span class="p">:</span><span class="n">ch</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">ch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chars</span><span class="p">)</span> <span class="p">}</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">block_size</span> <span class="o">=</span> <span class="n">block_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">vocab_size</span> <span class="o">=</span> <span class="n">vocab_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_sequences</span> <span class="o">=</span> <span class="n">math</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">encodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">o</span><span class="p">):</span>
        <span class="c1"># the item index o is ignored: every call samples a random window of block_size + 1 characters</span>
        <span class="n">i</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data</span><span class="p">)</span> <span class="o">-</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
        <span class="n">chunk</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="bp">self</span><span class="p">.</span><span class="n">block_size</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
        <span class="c1"># map each character to its integer id</span>
        <span class="n">dix</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">chunk</span><span class="p">]</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">dix</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">decodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">o</span><span class="p">):</span>
        <span class="n">t</span> <span class="o">=</span> <span class="s">''</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="bp">self</span><span class="p">.</span><span class="n">itos</span><span class="p">[</span><span class="n">s</span><span class="p">.</span><span class="n">item</span><span class="p">()]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">o</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">TitledStr</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>
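<p>Stripped of the fast.ai machinery, <code class="language-plaintext highlighter-rouge">CharTransform</code> boils down to a character vocabulary plus random-window encoding and decoding. The plain-Python sketch below illustrates that idea; the helpers <code class="language-plaintext highlighter-rouge">build_vocab</code>, <code class="language-plaintext highlighter-rouge">encode_random_chunk</code>, and <code class="language-plaintext highlighter-rouge">decode</code> are hypothetical names, not part of fast.ai or minGPT. One deliberate difference: it sorts the characters, because the order of <code class="language-plaintext highlighter-rouge">list(set(data))</code> can change between Python sessions, which could make the ids disagree with a model trained in another session.</p>

```python
import random


def build_vocab(text):
    """Map every distinct character to an integer id and back."""
    chars = sorted(set(text))  # sorted so the ids are stable across runs
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    return stoi, itos


def encode_random_chunk(text, stoi, block_size, rng=random):
    """Pick a random window of block_size + 1 characters and encode it as ids."""
    # random.randint includes its upper bound, unlike np.random.randint
    i = rng.randint(0, len(text) - (block_size + 1))
    chunk = text[i:i + block_size + 1]
    return [stoi[ch] for ch in chunk]


def decode(ids, itos):
    """Turn a list of ids back into a string."""
    return "".join(itos[i] for i in ids)


text = "Karena itu, saya tidak akan berubah."
stoi, itos = build_vocab(text)
ids = encode_random_chunk(text, stoi, block_size=8, rng=random.Random(0))
print(len(stoi), len(ids), decode(ids, itos))
```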

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sl</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">block_size</span> <span class="o">=</span> <span class="n">sl</span>
<span class="n">n_samples</span> <span class="o">=</span> <span class="n">math</span><span class="p">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">raw_text</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>

<span class="n">tls</span> <span class="o">=</span> <span class="n">TfmdLists</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">)),</span> <span class="n">tfms</span><span class="o">=</span><span class="p">[</span><span class="n">CharTransform</span><span class="p">(</span><span class="n">raw_text</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)],</span> <span class="n">split_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dl_type</span><span class="o">=</span><span class="n">LMDataLoader</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>data has 227914 characters, 93 unique.
</code></pre></div></div>
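<p>The value of <code class="language-plaintext highlighter-rouge">n_samples</code> follows from the corpus size printed above. Each sample is a window of <code class="language-plaintext highlighter-rouge">block_size + 1</code> characters, the convention minGPT’s character-level demo uses to get an input and its one-step-shifted target from a single chunk. A quick check of the arithmetic, hardcoding the numbers from this run:</p>

```python
import math

data_size = 227914  # characters in yabes-elia.txt, as printed above
block_size = 128    # sl in the notebook

# how many windows of block_size + 1 characters are needed to cover the corpus
n_samples = math.ceil(data_size / (block_size + 1))
print(n_samples)  # 1767
```

<p>Because <code class="language-plaintext highlighter-rouge">CharTransform.encodes</code> ignores the index it receives and always draws a random window, this number effectively only sets how many windows make up one epoch.</p>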

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">show_at</span><span class="p">(</span><span class="n">tls</span><span class="p">.</span><span class="n">train</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Faktanya, mengubah sejarah dunia itu tidak akan pernah semudah membalikkan telapak tangan, atau dalam hal ini, menuliskan komenta
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bs</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">dls</span> <span class="o">=</span> <span class="n">tls</span><span class="p">.</span><span class="n">dataloaders</span><span class="p">(</span><span class="n">bs</span><span class="o">=</span><span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">sl</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span><span class="p">.</span><span class="n">show_batch</span><span class="p">(</span><span class="n">max_n</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>text</th>
      <th>text_</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>ibadi? Well, saya orang praktis. Saat saya masih jadi Managing Editor PC Gamer Indonesia, saya tentu lebih pro dengan MOBA di PC</td>
      <td>badi? Well, saya orang praktis. Saat saya masih jadi Managing Editor PC Gamer Indonesia, saya tentu lebih pro dengan MOBA di PC.</td>
    </tr>
    <tr>
      <th>1</th>
      <td>al tidur sianah yang bisa memberikan jawaban jujur tentang siapa kita, bukan kuis-kuis di dunia maya yang tidak jelas algoritman</td>
      <td>l tidur sianah yang bisa memberikan jawaban jujur tentang siapa kita, bukan kuis-kuis di dunia maya yang tidak jelas algoritmany</td>
    </tr>
  </tbody>
</table>
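<p>In the batch shown above, <code class="language-plaintext highlighter-rouge">text_</code> is <code class="language-plaintext highlighter-rouge">text</code> shifted left by one character: at every position, the target is simply the next character. A minimal sketch of how such a pair can be cut from one window, where <code class="language-plaintext highlighter-rouge">make_lm_pair</code> is a hypothetical helper; fast.ai’s <code class="language-plaintext highlighter-rouge">LMDataLoader</code> does the equivalent internally on its concatenated stream of samples.</p>

```python
def make_lm_pair(ids):
    """Split a window of n + 1 tokens into an n-long input and its shifted target."""
    x = ids[:-1]  # what the model sees
    y = ids[1:]   # what it should predict at each position
    return x, y


window = "Well, saya"
x, y = make_lm_pair(list(window))
for inp, tgt in zip(x, y):
    print(repr(inp), "->", repr(tgt))
```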

<h3 id="dropouput-callback">DropOutput Callback</h3>

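<p>minGPT’s <code class="language-plaintext highlighter-rouge">GPT</code> model returns a tuple of logits and loss from its forward pass, whereas the fast.ai loss function expects only the predictions. This callback therefore replaces the Learner’s <code class="language-plaintext highlighter-rouge">self.learn.pred</code> with its first element, the logits, right after the prediction step.</p>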

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">DropOutput</span><span class="p">(</span><span class="n">Callback</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">after_pred</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">learn</span><span class="p">.</span><span class="n">pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">pred</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>
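<p>To see why the callback is needed, here is a plain-Python mock of the relevant part of a training step. It is not fast.ai’s actual implementation: <code class="language-plaintext highlighter-rouge">ToyLearner</code>, <code class="language-plaintext highlighter-rouge">toy_forward</code>, and <code class="language-plaintext highlighter-rouge">toy_loss</code> are hypothetical stand-ins. The only point it makes is that the model returns a <code class="language-plaintext highlighter-rouge">(logits, loss)</code> tuple, the loss function wants logits alone, and an <code class="language-plaintext highlighter-rouge">after_pred</code>-style hook swaps one for the other before the loss is computed.</p>

```python
class ToyLearner:
    """Hypothetical stand-in for the part of a fast.ai Learner we care about."""

    def __init__(self, model, loss_func, callbacks=()):
        self.model, self.loss_func, self.callbacks = model, loss_func, callbacks

    def one_batch(self, x, y):
        self.pred = self.model(x)  # minGPT-style model: returns (logits, loss)
        for cb in self.callbacks:  # fast.ai fires after_pred at this point
            cb(self)
        return self.loss_func(self.pred, y)


def drop_output(learn):
    """Same idea as DropOutput.after_pred: keep only the logits."""
    learn.pred = learn.pred[0]


def toy_forward(x):
    logits = [[0.1, 0.9] for _ in x]  # fake per-position scores for 2 classes
    return logits, None               # no targets passed, so loss is None


def toy_loss(logits, y):
    if isinstance(logits, tuple):
        raise TypeError("loss function received a tuple, expected logits")
    return sum(1 - row[t] for row, t in zip(logits, y)) / len(y)  # toy loss


learn = ToyLearner(toy_forward, toy_loss, callbacks=[drop_output])
print(learn.one_batch([0, 1], [1, 1]))
```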

<h3 id="model-mingpt">Model: minGPT</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">mconf</span> <span class="o">=</span> <span class="n">GPTConfig</span><span class="p">(</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">sl</span><span class="p">,</span> <span class="n">n_layer</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_head</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_embd</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">GPT</span><span class="p">(</span><span class="n">mconf</span><span class="p">)</span>
</code></pre></div></div>
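<p>As a rough sanity check on the model size, the parameter count can be worked out by hand. This sketch assumes the minGPT layout at the time of writing (separate key, query, value, and projection linear layers with biases, a 4× MLP, learned positional embeddings, and a bias-free output head) together with the 93-character vocabulary reported above; if your minGPT version differs, so will the total.</p>

```python
def gpt_param_count(vocab_size, block_size, n_layer, n_embd):
    """Count parameters of a minGPT-style model (layout assumed, see text)."""
    def linear(n_in, n_out, bias=True):
        return n_in * n_out + (n_out if bias else 0)

    layer_norm = 2 * n_embd                  # weight + bias
    attention = 4 * linear(n_embd, n_embd)   # key, query, value, proj
    mlp = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
    block = 2 * layer_norm + attention + mlp  # ln1, attn, ln2, mlp

    embeddings = vocab_size * n_embd + block_size * n_embd  # token + position
    head = linear(n_embd, vocab_size, bias=False)
    return embeddings + n_layer * block + layer_norm + head  # + final ln_f


print(gpt_param_count(vocab_size=93, block_size=128, n_layer=8, n_embd=512))  # 25380864
```

<p>Under these assumptions, that comes to roughly 25 million parameters.</p>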

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span> <span class="o">=</span> <span class="n">Learner</span><span class="p">(</span><span class="n">dls</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">loss_func</span><span class="o">=</span><span class="n">CrossEntropyLossFlat</span><span class="p">(),</span> <span class="n">opt_func</span><span class="o">=</span><span class="n">partial</span><span class="p">(</span><span class="n">Adam</span><span class="p">,</span> <span class="n">sqr_mom</span><span class="o">=</span><span class="mf">0.95</span><span class="p">,</span> <span class="n">wd</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
                <span class="n">cbs</span><span class="o">=</span><span class="p">[</span><span class="n">DropOutput</span><span class="p">])</span>
</code></pre></div></div>

<h3 id="training-model">Training Model</h3>

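<p>As is standard practice in fast.ai, we first ran the learning rate finder. It suggested <code class="language-plaintext highlighter-rouge">lr_min</code> $\approx 0.0033$, so we picked a learning rate of $3 \times 10^{-3}$.</p>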

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">lr_find</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
  warn("Your generator is empty.")

SuggestedLRs(lr_min=0.0033113110810518267, lr_steep=2.0892961401841603e-05)
</code></pre></div></div>

<center>
<img src="/images/2020/08/text-generation-with-mingpt-fastai/output_17_3.png" style="zoom: 70%;" />
</center>
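<p>The two numbers in <code class="language-plaintext highlighter-rouge">SuggestedLRs</code> come from a learning rate range test: the learning rate is increased exponentially over a short run while the loss is recorded. The sketch below is a simplified take on how fast.ai 2.0 derives its suggestions, as I understand them: <code class="language-plaintext highlighter-rouge">lr_min</code> is the learning rate at the lowest loss divided by 10, and <code class="language-plaintext highlighter-rouge">lr_steep</code> is where the loss falls fastest with respect to the log of the learning rate. The start, end, and iteration counts mirror the defaults of fast.ai’s <code class="language-plaintext highlighter-rouge">lr_find</code>; the loss curve is synthetic (not the one plotted above), and the library’s loss smoothing and trimming of the curve’s ends are left out.</p>

```python
import math


def lr_range(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Exponentially spaced learning rates, as used by an LR range test."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]


def suggest_lrs(lrs, losses):
    """Simplified version of the two suggestions (no smoothing, no trimming)."""
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10  # one order of magnitude before the loss bottoms out
    # slope of the loss with respect to log(lr); the most negative one is the steepest drop
    grads = [(losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
             for i in range(len(lrs) - 1)]
    i_steep = min(range(len(grads)), key=grads.__getitem__)
    return lr_min, lrs[i_steep]


# synthetic loss curve (NOT the real one above): flat, then falling, then diverging
lrs = lr_range()
losses = [3 - math.tanh(math.log10(lr) + 4) + math.exp(math.log10(lr) + 1) for lr in lrs]
lr_min, lr_steep = suggest_lrs(lrs, losses)
print(f"lr_min={lr_min:.1e}, lr_steep={lr_steep:.1e}")
```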

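<p>With that, we trained the model for 100 epochs with fast.ai’s one-cycle policy, using $3 \times 10^{-3}$ as the maximum learning rate. Since no validation set was defined, only the training loss is recorded, and <code class="language-plaintext highlighter-rouge">valid_loss</code> shows <code class="language-plaintext highlighter-rouge">None</code>.</p>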

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mf">3e-3</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>3.254104</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>1</td>
      <td>3.155325</td>
      <td>None</td>
      <td>00:15</td>
    </tr>
    <tr>
      <td>2</td>
      <td>3.099205</td>
      <td>None</td>
      <td>00:15</td>
    </tr>
    <tr>
      <td>3</td>
      <td>3.032215</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>4</td>
      <td>2.936109</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>5</td>
      <td>2.849230</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>6</td>
      <td>2.779433</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>7</td>
      <td>2.719887</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>8</td>
      <td>2.667851</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>9</td>
      <td>2.624543</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>10</td>
      <td>2.585148</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>11</td>
      <td>2.552182</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>12</td>
      <td>2.523161</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>13</td>
      <td>2.498031</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>14</td>
      <td>2.476660</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>15</td>
      <td>2.455955</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>16</td>
      <td>2.441366</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>17</td>
      <td>2.427677</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>18</td>
      <td>2.414104</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>19</td>
      <td>2.397461</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>20</td>
      <td>2.383328</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>21</td>
      <td>2.368615</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>22</td>
      <td>2.352587</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>23</td>
      <td>2.335341</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>24</td>
      <td>2.323342</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>25</td>
      <td>2.305508</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>26</td>
      <td>2.286461</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>27</td>
      <td>2.262887</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>28</td>
      <td>2.237531</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>29</td>
      <td>2.211838</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>30</td>
      <td>2.186196</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>31</td>
      <td>2.156658</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>32</td>
      <td>2.128527</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>33</td>
      <td>2.104312</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>34</td>
      <td>2.074495</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>35</td>
      <td>2.046017</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>36</td>
      <td>2.018104</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>37</td>
      <td>1.990814</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>38</td>
      <td>1.963953</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>39</td>
      <td>1.938050</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>40</td>
      <td>1.910195</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>41</td>
      <td>1.882767</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>42</td>
      <td>1.859885</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>43</td>
      <td>1.833001</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>44</td>
      <td>1.805273</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>45</td>
      <td>1.777778</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>46</td>
      <td>1.749810</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>47</td>
      <td>1.721224</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>48</td>
      <td>1.694282</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>49</td>
      <td>1.668665</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>50</td>
      <td>1.641540</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>51</td>
      <td>1.614098</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>52</td>
      <td>1.587708</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>53</td>
      <td>1.560743</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>54</td>
      <td>1.534708</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>55</td>
      <td>1.510127</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>56</td>
      <td>1.486278</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>57</td>
      <td>1.461563</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>58</td>
      <td>1.438166</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>59</td>
      <td>1.415540</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>60</td>
      <td>1.392969</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>61</td>
      <td>1.371182</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>62</td>
      <td>1.351205</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>63</td>
      <td>1.331026</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>64</td>
      <td>1.311882</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>65</td>
      <td>1.293381</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>66</td>
      <td>1.274096</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>67</td>
      <td>1.256531</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>68</td>
      <td>1.237806</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>69</td>
      <td>1.221424</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>70</td>
      <td>1.204520</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>71</td>
      <td>1.189105</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>72</td>
      <td>1.172827</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>73</td>
      <td>1.156720</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>74</td>
      <td>1.140753</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>75</td>
      <td>1.125648</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>76</td>
      <td>1.111875</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>77</td>
      <td>1.097298</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>78</td>
      <td>1.083305</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>79</td>
      <td>1.069097</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>80</td>
      <td>1.056546</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>81</td>
      <td>1.044658</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>82</td>
      <td>1.033119</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>83</td>
      <td>1.021210</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>84</td>
      <td>1.009997</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>85</td>
      <td>0.999994</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>86</td>
      <td>0.989661</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>87</td>
      <td>0.979982</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>88</td>
      <td>0.970661</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>89</td>
      <td>0.961383</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>90</td>
      <td>0.953398</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>91</td>
      <td>0.946190</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>92</td>
      <td>0.939140</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>93</td>
      <td>0.932855</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>94</td>
      <td>0.926477</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>95</td>
      <td>0.921115</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>96</td>
      <td>0.915792</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>97</td>
      <td>0.911426</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>98</td>
      <td>0.907237</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
    <tr>
      <td>99</td>
      <td>0.904241</td>
      <td>None</td>
      <td>00:14</td>
    </tr>
  </tbody>
</table>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>/usr/local/lib/python3.6/dist-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
  warn("Your generator is empty.")
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">recorder</span><span class="p">.</span><span class="n">plot_loss</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/08/text-generation-with-mingpt-fastai/output_20_0.png" style="zoom: 70%;" />
</center>
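<p><code class="language-plaintext highlighter-rouge">fit_one_cycle</code> does not train at $3 \times 10^{-3}$ throughout: the learning rate warms up to that maximum and then anneals to a tiny final value. The sketch below computes the learning rate at a given fraction of training, assuming fast.ai’s default arguments (<code class="language-plaintext highlighter-rouge">div=25</code>, <code class="language-plaintext highlighter-rouge">div_final=1e5</code>, <code class="language-plaintext highlighter-rouge">pct_start=0.25</code>) and cosine interpolation between the phases. It only covers the learning rate; by default fast.ai also cycles momentum from 0.95 down to 0.85 and back.</p>

```python
import math


def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pct, lr_max=3e-3, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at fraction pct (0 to 1) of training under a one-cycle schedule."""
    if pct < pct_start:  # warm up from lr_max / div to lr_max
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # then anneal down to lr_max / div_final
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))


for epoch in (0, 10, 25, 50, 75, 99):
    print(epoch, f"{one_cycle_lr(epoch / 100):.2e}")
```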

<h3 id="testing-model">Testing Model</h3>

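<p>After training, we can give the model a context phrase or sentence and let it generate a continuation. We used minGPT’s <code class="language-plaintext highlighter-rouge">sample</code> utility to generate the next 2000 characters, with a temperature of 0.9 and top-k sampling (k = 5).</p>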
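<p>Each sampling step produces one character. As I understand minGPT’s <code class="language-plaintext highlighter-rouge">sample</code> with <code class="language-plaintext highlighter-rouge">sample=True</code>, every step divides the last position’s logits by the temperature, keeps only the <code class="language-plaintext highlighter-rouge">top_k</code> highest-scoring characters, applies a softmax, and draws one of them at random; the context is also cropped to the most recent <code class="language-plaintext highlighter-rouge">block_size</code> (here 128) characters, so a long generation only ever sees that much of its own past. The plain-Python sketch below reproduces the per-step selection logic; <code class="language-plaintext highlighter-rouge">sample_next</code> is a hypothetical name, not part of minGPT.</p>

```python
import math
import random


def sample_next(logits, temperature=0.9, top_k=5, rng=random):
    """Pick the next token id from raw scores using temperature and top-k sampling."""
    scaled = [s / temperature for s in logits]  # below 1 sharpens, above 1 flattens
    top = sorted(range(len(scaled)), key=scaled.__getitem__, reverse=True)[:top_k]
    m = max(scaled[i] for i in top)  # subtract the max for numerical stability
    weights = [math.exp(scaled[i] - m) for i in top]  # softmax numerators, top k only
    return rng.choices(top, weights=weights, k=1)[0]  # choices() normalises the weights


logits = [2.0, 1.0, 0.5, -1.0, 3.0, 0.0, -2.0]
rng = random.Random(42)
draws = [sample_next(logits, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # only ids among the 5 highest scores can appear
```

<p>Lower temperatures and smaller values of k make the output more conservative and repetitive; higher values make it more varied but less coherent.</p>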


<h4 id="context-1-karena-itu">Context 1: “Karena itu,”</h4>

<p>In English: “Therefore,”.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">minGPT.mingpt.utils</span> <span class="kn">import</span> <span class="n">sample</span>

<span class="n">context</span> <span class="o">=</span> <span class="s">"Karena itu,"</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">context</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)[</span><span class="bp">None</span><span class="p">,...].</span><span class="n">to</span><span class="p">(</span><span class="n">dls</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mi">2000</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">sample</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">top_k</span><span class="o">=</span><span class="mi">5</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">completion</span> <span class="o">=</span> <span class="s">''</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">itos</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">i</span><span class="p">)]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">y</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">completion</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Karena itu, sistem gurus muluk disemungkinan di sini adalah kemuncegan banyak karakteristik pahlawan saya yang berbeda atas hidup kita bisa mendapatkan kehidupan hasil, keturunan, ataupun hilang rasa, satu hal yang bisa dijelaskan dengan kata-kata Anda tadi, kemungkinan besar, Anda terkritik dengan sendiri. Jika Anda tidak tahu kalah itu jauh lebih sulit dan lebih tertarik meraih pada konfirmasi dengan keatan kita, ataupun hal-hal lainnya karena sebenarnya ada satu hal yang menulis.

Saya pribadi, jika saya sangat merasa menyenangkan untuk komenangkap ini mungkin bisa berpikir diterakhir yang membutuhkan bahwa pribadi jika Anda tidak ada yang suka daripada satu kawan saya sudah berpasangan dalam menghubungan sesuatu Anda tidak suka dengan sekalipun, kategori saya yakin Anda juga masih sering berpikir saya bisa memaknakan para pembela atau malah sarat dengan personal tentang skeptisisme.

Saya juga tidak akan berubah waktu.

Namun, kebalikan dari kreatif seperti kita drumah, dan perspektif seorang istri ini.

Misalnya saja seperti ini, saya tidak pernah menghantarkan penutup artikel ini. Setidaknya, saya memang sudah bekerja dari sudut pandang menghasilkan kebetulanan saja, selanjutnya seperti ini. Satu hal yang membuat saya pernah menuliskan saya bekerja keras dan kembali buku sosial dan sebagai kuburan tahun berapa buku saya adalah sebagian besar tadi sebenarnya sudah tidur dari pasangan.


Di dunia riil, saya pribadi memiliki alam hidup yang berbeda dari segi yang saya pernah berada di depan kita, dan keluarga kita mau menikmati keseluarga ketimbang dua memahami segera pandai bagaimana relevan.

Dari sejumlah satu komunitas adalah seperti grafis di bawah ini untuk diri sendiri. Misalnya, satu hal yang sama-sama sekali, dan kuis-kuis di jaman sekarang ini, ada banyak h kawan-kawan saya yang memegang tidak berbagai semua saya kira semua pasti tidak menyebutkan berakhir demikian? Kita juga tidak akan lelahnya selalu berada dengan satu cara yang sama (saya). Namun juga saya tahu
</code></pre></div></div>

<h4 id="context-2-filosofi-saya-adalah">Context 2: “Filosofi saya adalah”</h4>

<p>In English: “My philosophy is”.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">context</span> <span class="o">=</span> <span class="s">"Filosofi saya adalah"</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">context</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)[</span><span class="bp">None</span><span class="p">,...].</span><span class="n">to</span><span class="p">(</span><span class="n">dls</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mi">2000</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">sample</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">top_k</span><span class="o">=</span><span class="mi">5</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">completion</span> <span class="o">=</span> <span class="s">''</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">itos</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">i</span><span class="p">)]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">y</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">completion</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Filosofi saya adalah ketidakpastian dan sang pernah berhasil mengolahnya kebanyakan soal image yang mengatakan bahwa saya menghabiskan waktu untuk berubah dambaan hati Anda, dan sebelum tulisan ini juga sangat baik itu memperti sebelumnya.

Saya kira semua suka saya dibelikan banyak mobile.

Ditambah lagi, atau proses berpikir lebih jauh masing-masing. Pasalnya, merasa tidak memuaskan keras untuk memperkayakan diri dengan kepentingan yang saya tadi, seperti pemilik solusi yang lebih beruntung ketimbang mendengarkan gilita semua itu tidak aktif dan marah terhadap kebebahagiaan di kondisi lainnya sebelumnya.

Akhirnya, saya pribadi juga melihat ketika mana yang semua hal yang bisa Anda tidak akan mengeluhi kegagalanan, saya kira semua bisa sampai ke titik ini – tulisan saya ditujukan di sini adalah kesatuan yang bisa kita ahadapi di kepentingan industri ini membutuhkan sebagian tadi berpikir – karena mencerita jadi sebuah pasangan atau berpikir lebih jauh.

Maksud saya terhasal menghadapi sesuai dengan keputusan yang saya rasakan. Setiap kita pasti punya keinginan bisa jadi profesional, berpikir berbasiskan bisa jadi salah satu cenderung untuk mencari tahu alias karena sifat kepintaran tersebut di sana.

Misalnya, terasal dari satu tim/ gumen, pasti pacaran tadi pacarnya, pernah saya bisa mengajak keluarga dalam membela kerap bersisa dapat membebaskan apa yang kita percayai adalah kesuksesan dan berbagi berpikir internet dan menggelitik. Tutup jawaban memang mudah mendorong untuk mencari kesadar dan satu pasangan Anda, selama 15 tahun yang berbeda dari kondisi yang lainnya.

Namun demikian, kecenderungan untuk melarang lebih besar ketimbang harus kita sedih memiliki personalitas kita bisa jadi tolak ukur dan kepintaran seseorang seperti seperti bahkan sebuah soal game, seperti seperti yang seperti apakah yang bisa semesta dan mengurus rumah tangga.

Saya adalah rekaman sepenuhnya dengan tangan Anda, Anda akan pernah mendengar adalah orang-orang yang berpikiran – siedak akan berarti saya
</code></pre></div></div>

<h4 id="context-3-bagi-saya-hidup">Context 3: “Bagi saya, hidup”</h4>

<p>In English: “For me, life”.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">context</span> <span class="o">=</span> <span class="s">"Bagi saya, hidup"</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">context</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)[</span><span class="bp">None</span><span class="p">,...].</span><span class="n">to</span><span class="p">(</span><span class="n">dls</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mi">2000</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">sample</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">top_k</span><span class="o">=</span><span class="mi">5</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">completion</span> <span class="o">=</span> <span class="s">''</span><span class="p">.</span><span class="n">join</span><span class="p">([</span><span class="n">dls</span><span class="p">.</span><span class="n">char_transform</span><span class="p">.</span><span class="n">itos</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">i</span><span class="p">)]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">y</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">completion</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Bagi saya, hidup itu sebenarnya manusia itu bisa berubah-ubah dan kesamaan Anda…

So, saya pribadi mencari saya, kemungkinan besar, Anda juga akan memuaskan keruntungan keinginan sosial dan memproses bebagai sebuah kehidupannya seperti karena pil berargumen bahwa pada pengecualian. Jika Anda bisa membaca itu saja yang sebenarnya tak punya kesedihan menuntut dunia. Dari perspektif seorang berbeda jadi berusaha dengan masalah dunia nyata, termasuk sebagian dan berbeda. Sebelum kita, meminilah kepuasannya sesama seperti ini adalah satu hal yang pasti pernah merasakan hal yang sama, termasuk dalam hidup itu terjadi.

Siapakah yang saya tidak mengalahkan pusat ini seringkali disadari. Dalam perspektif yang membuat Anda pernah dengan percaya setiap orang idealisme itu biasanya merupakannya.

Namun setiap kita punya keingintahuan yang sama sama seperti ini, saya percaya bahwa adalah rekan besar tadi setiap orang yang paling suka merasa melihat hal tersebut tertarik untuk menjadi bagian dari kebencian Andscommbias yang bernada dari segi sebenarnya bisa berkurang terus bekerja atau bisa memiliki cerita semakin banyak orang suami itu terjadi ketika kita masih mencari kesalahan prestasi atau tidak berawal dengan argumen yang namanya pendapat yang berbeda, dengan sedikit juga akan lakukan dari segudang seksual. Tokoh karena kita punya sudah berpasangan tahun lalu bagaimana jika kita berada di misalnya. Namun, kenyataannya, banyak orang tua, anak ‘multiplaya yang digunakan oleh orang lain – meski tidak ada yang lainnya.

Sebenarnya apa? Kegalauan, keyakinan Anda tujuan selalu mengerti pasangan Anda selalu merasa saat mengakui sesuatu di saat Anda.

Memang, saya sudah menyarankan pertama atau sistem berbeda di sini saat ini.

Akhirnya, tidak sama seperti yang saya rasakan. Saya kira saya tahu bahwa kita bisa saja memiliki hasrat tersebut karena tidak akan pernah terlalu memang negatif lainnya.

Misalnya saja seperti ini, saya kira saya juga tidak mau berhadapan dengan soal lainnya.

Sayangnya, m
</code></pre></div></div>

<h2 id="closing-remarks">Closing Remarks</h2>

<h3 id="conclusion">Conclusion</h3>

<p>To sum up, here are several of my remarks:</p>

<ul>
  <li>McGuire showed how easy it is to integrate fast.ai with PyTorch models and libraries.</li>
  <li>Fast.ai abstracts away repetitive tasks such as writing the training loop for the model, learning rate scheduling, etc.</li>
  <li>Karpathy’s minGPT is very versatile. Despite having far fewer parameters than OpenAI’s GPT, it still showed good results.</li>
  <li>Although some of the generated sentences lack proper grammar, it’s still interesting to let the model write text in the style of mas Yabes Elia.</li>
</ul>

<p>I’ve learned a lot simply by modifying McGuire’s code. As a novice in DL, I found language modelling to be entirely new territory. I’m excited to see what DL is capable of across applications, and I hope you’ve learned something too, like I did!</p>

<h3 id="credits">Credits</h3>

<ul>
  <li>morganmcg1’s <a href="https://gist.github.com/morganmcg1/b2a26e213482d3355a3d3a64c91e94ac">A Quick Demo of Andrej Karpathy’s minGPT Play Char Demo</a>.</li>
  <li>karpathy’s <a href="https://github.com/karpathy/minGPT">minGPT</a>.</li>
  <li>Yabes Elia’s <a href="https://zilbest.com/">Zilbest</a> blog posts.</li>
</ul>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Transformer" /><summary type="html"><![CDATA[Andrej Karpathy, Tesla’s AI Director, released minGPT, a mini version of OpenAI’s GPT. Normally a GPT would have billions of parameters and would take hours to train. Karpathy’s approach is to provide a smaller version of GPT, hence the name minGPT.]]></summary></entry><entry><title type="html">MNIST Classification with Quantum Neural Network</title><link href="https://wilsonwongso.dev/posts/2020/07/mnist-qnn/" rel="alternate" type="text/html" title="MNIST Classification with Quantum Neural Network" /><published>2020-07-14T00:00:00+10:00</published><updated>2020-07-14T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/07/mnist-qnn</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/07/mnist-qnn/"><![CDATA[<p><a href="https://www.tensorflow.org/">Tensorflow</a> is one of the most widely used deep learning frameworks today, bundled with many features for end-to-end deep learning workflows. Recently, they announced a new library built on top of Tensorflow, called <a href="https://www.tensorflow.org/quantum">Tensorflow Quantum</a>. Tensorflow Quantum integrates with <a href="https://github.com/quantumlib/Cirq">Cirq</a>, a Python framework for building and simulating quantum circuits, and the two work well together for Quantum Machine Learning tasks.</p>

<h3 id="quantum-computer-simulator">Quantum Computer Simulator</h3>

<p>Tensorflow Quantum provides a default <a href="https://github.com/tensorflow/quantum/tree/v0.3.0/tensorflow_quantum/core/qsim">simulator backend written in C++</a>. It is also possible, although slower, to run circuits on a Cirq simulator or on other backends, such as a real quantum computer. However, today’s quantum computers are still very noisy and sensitive to interference, so the QNN is run on the C++ simulator backend for simplicity. The aim is to experiment with the hybrid quantum-classical algorithms available today and to gauge the potential of Quantum Machine Learning once fault-tolerant quantum computers become available.</p>

<center>
<img src="/images/sycamore-processor.png" style="zoom: 70%;" /><br />
<figcaption><i>Photograph of the Sycamore processor | Erik Lucero</i></figcaption>
</center>

<h3 id="quantum-neural-networks">Quantum Neural Networks</h3>

<p>One realization of Quantum Machine Learning is the Quantum Neural Network (QNN). Unlike the Hybrid Neural Networks discussed in the <a href="https://wilsonwongso.dev/blog/jupyter/code/python/deeplearning/2020/07/13/mnist-qml-qiskit.html">previous blog</a>, a QNN runs purely on a quantum circuit made only of quantum gates. It does not combine classical and quantum neural network layers, and it works quite differently from a classical neural network, at least for now.</p>

<h3 id="mnist-classification">MNIST Classification</h3>

<p>Again, we’ll be classifying images from the <a href="http://yann.lecun.com/exdb/mnist/">MNIST dataset</a>, this time with the QNN. The following blocks of code are based on the Tensorflow Quantum tutorial <a href="https://www.tensorflow.org/quantum/tutorials/mnist">MNIST classification</a>. The algorithm follows a paper by <a href="https://arxiv.org/abs/1802.06002">Farhi et al.</a>, which is a must-read for the concepts behind the QNN implemented here and the reasoning for its design.</p>

<h2 id="code">Code</h2>

<h3 id="loading-data">Loading Data</h3>

<h4 id="rescaling-images">Rescaling Images</h4>

<p>As mentioned, we’ll be using the MNIST dataset, whose images are 28x28 pixels each. We rescale the pixel values from the $[0, 255]$ range to the $[0.0, 1.0]$ range.</p>

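<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">collections</span>

<span class="kn">import</span> <span class="nn">cirq</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>

<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">mnist</span><span class="p">.</span><span class="n">load_data</span><span class="p">()</span>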
<span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[...,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span><span class="o">/</span><span class="mf">255.0</span><span class="p">,</span> <span class="n">x_test</span><span class="p">[...,</span> <span class="n">np</span><span class="p">.</span><span class="n">newaxis</span><span class="p">]</span><span class="o">/</span><span class="mf">255.0</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Number of original training examples:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Number of original test examples:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_test</span><span class="p">))</span>
</code></pre></div></div>
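<p>The <code>[..., np.newaxis]</code> indexing above adds a trailing channel axis. This produces the <code>[batch, height, width, channels]</code> layout that <code>tf.image.resize</code> accepts when the images are downsampled later. A small NumPy check on a hypothetical toy batch (not real MNIST data) shows the new shape and the rescaled value range:</p>

```python
import numpy as np

# Hypothetical stand-in for two 28x28 uint8 MNIST images
batch = np.zeros((2, 28, 28), dtype=np.uint8)
batch[0, 10, 10] = 255  # one fully "on" pixel

# Same operation as in the post: add a channel axis, then rescale to [0.0, 1.0]
scaled = batch[..., np.newaxis] / 255.0

print(scaled.shape, scaled.dtype, scaled.min(), scaled.max())
# (2, 28, 28, 1) float64 0.0 1.0
```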

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Number of original training examples: 60000
Number of original test examples: 10000
</code></pre></div></div>

<p>Since the final “output layer”, which here is the readout qubit, consists of just a single qubit, we will only classify 2 distinct classes: 3s and 6s.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">filter_36</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="c1"># Keep only the 3s and 6s</span>
    <span class="n">keep</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span> <span class="o">==</span> <span class="mi">3</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span><span class="n">y</span> <span class="o">==</span> <span class="mi">6</span><span class="p">)</span>
    <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">keep</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="n">keep</span><span class="p">]</span>
    <span class="c1"># Relabel as booleans: True for a 3, False for a 6</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">==</span> <span class="mi">3</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span><span class="n">y</span>
</code></pre></div></div>
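<p>With a single readout qubit, the circuit produces one number per image: the expectation value of a measurement on that qubit, which lies in $[-1, 1]$ for a Pauli-Z measurement. One number is enough to separate two classes but not ten. The NumPy sketch below illustrates such a binary decision rule. The readout values are made up. Mapping the boolean labels from <code>filter_36</code> to $\pm 1$ and classifying by the sign of the expectation are assumptions for illustration, not code from this post:</p>

```python
import numpy as np

# Labels in the format filter_36 returns: True for a 3, False for a 6 (toy values)
y = np.array([True, False, True, False])

# Map True/False to +1/-1 so the labels share the range of a Pauli-Z expectation
y_pm = 2.0 * y.astype(np.float32) - 1.0

# Hypothetical readout expectation values, one per image, each in [-1, 1]
expectations = np.array([0.42, -0.75, -0.10, -0.33])

# Assumed decision rule: the sign of the expectation picks the class
accuracy = np.mean(np.sign(expectations) == y_pm)
print(y_pm, accuracy)
```

<p>Here the third image is misclassified, because its expectation is negative while its label is a 3.</p>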

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">filter_36</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">filter_36</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Number of filtered training examples:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Number of filtered test examples:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_test</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Number of filtered training examples: 12049
Number of filtered test examples: 1968
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">y_train</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>

<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">])</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
&lt;matplotlib.colorbar.Colorbar at 0x7f453da67898&gt;
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-qnn/output_8_2.png" style="zoom: 70%;" />
</center>

<h4 id="downsampling-images">Downsampling Images</h4>

<p>The images are then downsampled to 4x4 pixels, since we’ll only be using 17 qubits: 16 for the image pixels and 1 for the readout. This lowers the resolution so much that the images no longer resemble the originals. However, only a limited number of qubits can be simulated, so downsampling to low-resolution images is required.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_train_small</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">image</span><span class="p">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">x_test_small</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">image</span><span class="p">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">)).</span><span class="n">numpy</span><span class="p">()</span>
</code></pre></div></div>
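<p>The qubit budget is this tight because a simulator that stores the full state vector of $n$ qubits must hold $2^n$ complex amplitudes. The short sketch below counts that memory, assuming single-precision complex amplitudes (8 bytes each). The helper name is hypothetical, and the sketch says nothing about how the TFQ simulator manages memory internally:</p>

```python
BYTES_PER_AMPLITUDE = 8  # assumption: complex64, i.e. two float32 values


def state_vector_bytes(num_qubits: int) -> int:
    """Memory needed for a dense state vector of `num_qubits` qubits."""
    return (2 ** num_qubits) * BYTES_PER_AMPLITUDE


small = 4 * 4 + 1   # 4x4 image pixels + 1 readout qubit = 17 qubits
full = 28 * 28 + 1  # full-resolution image + readout = 785 qubits

print(small, state_vector_bytes(small) / 2 ** 20, "MiB")  # 17 1.0 MiB
# Even 30 qubits already need 8 GiB, so 785 qubits is far out of reach
print(state_vector_bytes(30) / 2 ** 30, "GiB")  # 8.0 GiB
```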

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">y_train</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>

<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">x_train_small</span><span class="p">[</span><span class="mi">0</span><span class="p">,:,:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">vmin</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
&lt;matplotlib.colorbar.Colorbar at 0x7f453022d7b8&gt;
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-qnn/output_11_2.png" style="zoom: 70%;" />
</center>

<h4 id="removing-contradicting-images">Removing Contradicting Images</h4>

<p>Additionally, after downsampling, some images become ambiguous: the same 4x4 image can appear with more than one label (i.e. as both a 3 and a 6). We’ll remove those contradicting image-label pairs from the dataset.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">remove_contradicting</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
    <span class="n">mapping</span> <span class="o">=</span> <span class="n">collections</span><span class="p">.</span><span class="n">defaultdict</span><span class="p">(</span><span class="nb">set</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">x</span><span class="p">,</span><span class="n">y</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span><span class="n">ys</span><span class="p">):</span>
        <span class="n">mapping</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">flatten</span><span class="p">())].</span><span class="n">add</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>

    <span class="n">new_x</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">new_y</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">x</span><span class="p">,</span><span class="n">y</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">mapping</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">flatten</span><span class="p">())]</span>
        <span class="c1"># Keep only images that map to exactly one label</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">new_x</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">new_y</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">labels</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span>

    <span class="n">num_3</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">mapping</span><span class="p">.</span><span class="n">values</span><span class="p">()</span> <span class="k">if</span> <span class="bp">True</span> <span class="ow">in</span> <span class="n">value</span><span class="p">)</span>
    <span class="n">num_6</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">mapping</span><span class="p">.</span><span class="n">values</span><span class="p">()</span> <span class="k">if</span> <span class="bp">False</span> <span class="ow">in</span> <span class="n">value</span><span class="p">)</span>
    <span class="n">num_both</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">mapping</span><span class="p">.</span><span class="n">values</span><span class="p">()</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)</span>

    <span class="k">print</span><span class="p">(</span><span class="s">"Number of unique images:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">mapping</span><span class="p">.</span><span class="n">values</span><span class="p">()))</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Number of 3s: "</span><span class="p">,</span> <span class="n">num_3</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Number of 6s: "</span><span class="p">,</span> <span class="n">num_6</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Number of contradictory images: "</span><span class="p">,</span> <span class="n">num_both</span><span class="p">)</span>
    <span class="k">print</span><span class="p">()</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Initial number of examples: "</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">xs</span><span class="p">))</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Remaining non-contradictory examples: "</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_x</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">new_x</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">new_y</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_train_nocon</span><span class="p">,</span> <span class="n">y_train_nocon</span> <span class="o">=</span> <span class="n">remove_contradicting</span><span class="p">(</span><span class="n">x_train_small</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Number of unique images: 10387
Number of 3s:  4961
Number of 6s:  5475
Number of contradictory images:  49

Initial number of examples:  12049
Remaining non-contradictory examples:  11520
</code></pre></div></div>
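<p>The printed counts are consistent with each other. <code>num_3</code> and <code>num_6</code> count unique images, so the 49 contradictory images are counted in both, and inclusion-exclusion recovers the number of unique images. An identical 4x4 image can occur many times, so dropping those 49 unique images removes 529 training examples. A quick check using only the numbers printed above:</p>

```python
# Counts copied from the output of remove_contradicting above
num_unique, num_3, num_6, num_both = 10387, 4961, 5475, 49
initial, remaining = 12049, 11520

# Inclusion-exclusion: contradictory images appear in both num_3 and num_6
assert num_3 + num_6 - num_both == num_unique

dropped = initial - remaining
print(dropped)  # 529 examples belonged to the 49 contradictory images
```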

<h4 id="encoding-data-as-quantum-circuits">Encoding Data as Quantum Circuits</h4>

<p>We have to find a way to represent our images as qubits, and the method used in the tutorial is straightforward. We set a threshold value, 0.5 in our case. For every pixel whose value is greater than that, we append Cirq’s X gate on the corresponding qubit, which flips its state from $|0\rangle$ to $|1\rangle$ (i.e. signifying that the pixel is “on”).</p>
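<p>Every qubit starts in $|0\rangle$ and only X gates are applied, so the encoding circuit prepares a single computational basis state whose bitstring is the flattened binary image. The NumPy sketch below builds that state vector directly for a hypothetical 2x2 image. It assumes the first flattened pixel is the most significant bit of the basis-state index, an ordering chosen here for illustration:</p>

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=np.complex64)  # Pauli-X (NOT) gate
ZERO = np.array([1, 0], dtype=np.complex64)         # basis state 0 of one qubit


def encode_basis_state(binary_pixels):
    """State vector after applying X to every qubit whose pixel is 1.

    Assumes qubit 0 (the first flattened pixel) is the most significant
    bit of the basis-state index.
    """
    state = np.array([1], dtype=np.complex64)
    for pixel in np.ravel(binary_pixels):
        qubit = X @ ZERO if pixel else ZERO  # X flips the qubit to state 1
        state = np.kron(state, qubit)        # tensor product with the qubits so far
    return state


image = np.array([[1, 0], [0, 1]])  # hypothetical 2x2 binarized image
state = encode_basis_state(image)
index = int(np.argmax(np.abs(state)))
print(index, format(index, "04b"))  # 9 1001
```

<p>The 4x4 case works the same way with 16 qubits and a state vector of $2^{16}$ entries, exactly one of which is non-zero.</p>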

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">THRESHOLD</span> <span class="o">=</span> <span class="mf">0.5</span>

<span class="n">x_train_bin</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_train_nocon</span> <span class="o">&gt;</span> <span class="n">THRESHOLD</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">x_test_bin</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_test_small</span> <span class="o">&gt;</span> <span class="n">THRESHOLD</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">convert_to_circuit</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
    <span class="s">"""Encode truncated classical image into quantum datapoint."""</span>
    <span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">ndarray</span><span class="p">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
    <span class="n">qubits</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">.</span><span class="n">rect</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">circuit</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Circuit</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">values</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">value</span><span class="p">:</span>
            <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">X</span><span class="p">(</span><span class="n">qubits</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
    <span class="k">return</span> <span class="n">circuit</span>


<span class="n">x_train_circ</span> <span class="o">=</span> <span class="p">[</span><span class="n">convert_to_circuit</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">x_train_bin</span><span class="p">]</span>
<span class="n">x_test_circ</span> <span class="o">=</span> <span class="p">[</span><span class="n">convert_to_circuit</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">x_test_bin</span><span class="p">]</span>
</code></pre></div></div>
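
<p>To make the encoding concrete without needing Cirq, here is a small NumPy sketch that lists which grid positions would receive an X-gate. The helper <code class="language-plaintext highlighter-rouge">flipped_qubit_coords</code> is hypothetical. It assumes row-major flattening, which matches both <code class="language-plaintext highlighter-rouge">np.ndarray.flatten</code> and the qubit ordering of <code class="language-plaintext highlighter-rouge">cirq.GridQubit.rect</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def flipped_qubit_coords(image, threshold=0.5):
    """Hypothetical helper: (row, col) of every pixel that would get an X-gate."""
    binary = np.asarray(image) > threshold  # same strict test as x_train_bin
    cols = binary.shape[1]
    # np.flatnonzero walks the image in row-major order, like flatten()
    return [divmod(int(i), cols) for i in np.flatnonzero(binary)]


image = np.array([
    [0.0, 0.9, 0.0, 0.0],
    [0.0, 0.6, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0],
    [0.2, 0.7, 0.8, 0.0],
])
print(flipped_qubit_coords(image))  # [(0, 1), (1, 1), (3, 1), (3, 2)]
</code></pre></div></div>

<p>Note that the pixel at exactly 0.5 stays in $|0\rangle$, since the comparison is strict.</p>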

<p>Let’s see what one of our training examples looks like once encoded into a circuit. Note that qubits without any operations aren’t shown.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">SVGCircuit</span><span class="p">(</span><span class="n">x_train_circ</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
</code></pre></div></div>

<center>
<img src="/images/train-sample-data.svg" style="zoom: 70%;" /><br />
<figcaption><i>Sample Training Data as Circuit | Tensorflow Quantum</i></figcaption>
</center>

<p>Lastly, to use these circuits in TensorFlow Quantum, we have to convert them from Cirq circuits into tensors.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_train_tfcirc</span> <span class="o">=</span> <span class="n">tfq</span><span class="p">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">x_train_circ</span><span class="p">)</span>
<span class="n">x_test_tfcirc</span> <span class="o">=</span> <span class="n">tfq</span><span class="p">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">x_test_circ</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="quantum-neural-network">Quantum Neural Network</h3>

<p>Now that our data is encoded in a form that can flow through TensorFlow Quantum’s layers, we can start building our model. The type of QNN implemented in the paper uses two-qubit gates that <em>connect</em> every data qubit in the circuit to the readout qubit. At the end of the circuit, the expectation value of the readout qubit is measured and used as the basis for the model’s classification.</p>

<h4 id="building-circuit-layers">Building Circuit Layers</h4>

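<p>Each layer uses $n$ instances of the same gate, one per data qubit, each acting on that data qubit together with the readout qubit. Every gate is raised to the power of its own <code class="language-plaintext highlighter-rouge">sympy</code> symbol, which becomes a trainable parameter. The following class adds one such layer to a circuit.</p>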

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">CircuitLayerBuilder</span><span class="p">():</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_qubits</span><span class="p">,</span> <span class="n">readout</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">data_qubits</span> <span class="o">=</span> <span class="n">data_qubits</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">readout</span> <span class="o">=</span> <span class="n">readout</span>

    <span class="k">def</span> <span class="nf">add_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">circuit</span><span class="p">,</span> <span class="n">gate</span><span class="p">,</span> <span class="n">prefix</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">qubit</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data_qubits</span><span class="p">):</span>
            <span class="n">symbol</span> <span class="o">=</span> <span class="n">sympy</span><span class="p">.</span><span class="n">Symbol</span><span class="p">(</span><span class="n">prefix</span> <span class="o">+</span> <span class="s">'-'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">))</span>
            <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">gate</span><span class="p">(</span><span class="n">qubit</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">readout</span><span class="p">)</span><span class="o">**</span><span class="n">symbol</span><span class="p">)</span>
</code></pre></div></div>
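
<p>Since every gate gets its own symbol named <code class="language-plaintext highlighter-rouge">prefix-i</code>, the number of trainable parameters is simply the number of data qubits times the number of layers. Here is a quick plain-Python check of that naming scheme, using the <code class="language-plaintext highlighter-rouge">xx1</code> and <code class="language-plaintext highlighter-rouge">zz1</code> prefixes that <code class="language-plaintext highlighter-rouge">create_quantum_model</code> uses on its $4 \times 4$ grid. The helper <code class="language-plaintext highlighter-rouge">layer_symbol_names</code> is hypothetical and only mirrors the string logic in <code class="language-plaintext highlighter-rouge">add_layer</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def layer_symbol_names(prefix, num_data_qubits):
    # Same string construction as add_layer: prefix + '-' + str(i)
    return [prefix + '-' + str(i) for i in range(num_data_qubits)]


num_data_qubits = 4 * 4  # cirq.GridQubit.rect(4, 4)
names = layer_symbol_names("xx1", num_data_qubits) + layer_symbol_names("zz1", num_data_qubits)

print(names[:2], names[-1], len(names))  # ['xx1-0', 'xx1-1'] zz1-15 32
</code></pre></div></div>

<p>The distinct prefixes matter: two <code class="language-plaintext highlighter-rouge">sympy</code> symbols with the same name are the same symbol, so reusing a prefix would make two layers share their parameters. The total of 32 matches the parameter count reported for the <code class="language-plaintext highlighter-rouge">PQC</code> layer in the model summary.</p>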

<p>Let’s see what it looks like in a sample circuit.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">demo_builder</span> <span class="o">=</span> <span class="n">CircuitLayerBuilder</span><span class="p">(</span><span class="n">data_qubits</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">.</span><span class="n">rect</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span>
                                   <span class="n">readout</span><span class="o">=</span><span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>

<span class="n">circuit</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Circuit</span><span class="p">()</span>
<span class="n">demo_builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">XX</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="s">'xx'</span><span class="p">)</span>
<span class="n">SVGCircuit</span><span class="p">(</span><span class="n">circuit</span><span class="p">)</span>
</code></pre></div></div>

<center>
<img src="/images/sample-circuit.svg" style="zoom: 70%;" /><br />
<figcaption><i>Sample Circuit | Tensorflow Quantum</i></figcaption>
</center>

<p>As you can see, every data qubit (4 in this case) is connected to the readout qubit via an Ising ($XX$) coupling gate.</p>
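
<p>Each coupling gate is raised to a trainable exponent $t$. To build some intuition for what that exponent does, the NumPy sketch below constructs $(P \otimes P)^t$ for $P \in \{X, Z\}$. It gives the $+1$ eigenspace a phase of $1$ and the $-1$ eigenspace a phase of $e^{i\pi t}$. I’m assuming this matches Cirq’s convention for <code class="language-plaintext highlighter-rouge">cirq.XX</code> and <code class="language-plaintext highlighter-rouge">cirq.ZZ</code> with their default global shift. The helper <code class="language-plaintext highlighter-rouge">coupling_power</code> is hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
I4 = np.eye(4, dtype=complex)


def coupling_power(pauli, t):
    """Hypothetical helper: kron(pauli, pauli) ** t (assumed Cirq phase convention)."""
    pp = np.kron(pauli, pauli)   # XX or ZZ, eigenvalues +1 and -1
    plus = (I4 + pp) / 2         # projector onto the +1 eigenspace
    minus = (I4 - pp) / 2        # projector onto the -1 eigenspace
    return plus + np.exp(1j * np.pi * t) * minus


# ZZ**0.5 is diagonal: phase 1 on |00> and |11>, phase i on |01> and |10>
print(np.round(coupling_power(Z, 0.5).diagonal(), 3))
</code></pre></div></div>

<p>At $t = 0$ the gate is the identity, so a parameter of zero effectively switches that coupling off, while $t = 1$ gives the full $X \otimes X$ (or $Z \otimes Z$). Values in between give a continuous family of unitaries, which is what makes these exponents trainable by gradient descent.</p>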

<h4 id="creating-quantum-model">Creating Quantum Model</h4>

<p>With the layer builder class ready, we can create the quantum model for our QNN. Instead of using only a layer of Ising ($XX$) coupling gates, we’ll also add a layer of Ising ($ZZ$) coupling gates, so that every data qubit is coupled to the readout qubit by one gate of each kind. Each of these gates has its own parameter, which the model will learn to optimize during training.</p>

<p>Notice that we’re adding two initial gates to the readout qubit: an $X$ gate to flip it into the state $|1\rangle$, and an $H$ gate to put it into superposition. After all the Ising coupling gates, we append another $H$ gate to bring the readout qubit out of superposition. Finally, we return $Z$ on the readout qubit as the operator whose expectation value the model will measure.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">create_quantum_model</span><span class="p">():</span>
    <span class="n">data_qubits</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">.</span><span class="n">rect</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>  <span class="c1"># a 4x4 grid.
</span>    <span class="n">readout</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>         <span class="c1"># a single qubit at [-1,-1]
</span>    <span class="n">circuit</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Circuit</span><span class="p">()</span>

    <span class="c1"># Prepare the readout qubit.
</span>    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">X</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>
    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">H</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>

    <span class="n">builder</span> <span class="o">=</span> <span class="n">CircuitLayerBuilder</span><span class="p">(</span>
        <span class="n">data_qubits</span> <span class="o">=</span> <span class="n">data_qubits</span><span class="p">,</span>
        <span class="n">readout</span><span class="o">=</span><span class="n">readout</span><span class="p">)</span>

    <span class="c1"># Then add layers (experiment by adding more).
</span>    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">XX</span><span class="p">,</span> <span class="s">"xx1"</span><span class="p">)</span>
    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">ZZ</span><span class="p">,</span> <span class="s">"zz1"</span><span class="p">)</span>

    <span class="c1"># Finally, prepare the readout qubit.
</span>    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">H</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Z</span><span class="p">(</span><span class="n">readout</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model_circuit</span><span class="p">,</span> <span class="n">model_readout</span> <span class="o">=</span> <span class="n">create_quantum_model</span><span class="p">()</span>
</code></pre></div></div>
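
<p>The effect of the readout preparation can be checked with a few lines of NumPy on a single-qubit state vector (a sketch, independent of Cirq). $X$ takes $|0\rangle$ to $|1\rangle$, $H$ then gives an equal superposition whose $Z$ expectation is $0$, and the final $H$ undoes it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

ket0 = np.array([1, 0], dtype=complex)
X = np.array([[0, 1], [1, 0]], dtype=complex)
H = np.array([[1, 1], [1, -1]], dtype=complex) / np.sqrt(2)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def z_expectation(state):
    # Expectation value of Z for a normalised single-qubit state vector
    return float(np.real(np.vdot(state, Z @ state)))


after_x = X @ ket0          # |1>
after_xh = H @ after_x      # (|0> - |1>) / sqrt(2), an equal superposition
after_xhh = H @ after_xh    # H is its own inverse, so this is |1> again

for label, state in [("X", after_x), ("X, H", after_xh), ("X, H, H", after_xhh)]:
    print(label, round(z_expectation(state), 6))
</code></pre></div></div>

<p>In other words, if every coupling exponent were $0$ (so each coupling gate acts as the identity), the readout would end in $|1\rangle$ and the model would output $-1$ for every image. It is the trained couplings between the two $H$ gates that let the data qubits move this expectation value.</p>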

<p>The model is fairly large, with 17 qubits in total (16 data qubits plus the readout qubit). Laid out as a flat circuit, it looks like the following:</p>

<center>
<img src="/images/qnn-model.png" style="zoom: 70%;" /><br />
<figcaption><i>The Quantum Neural Network Model</i></figcaption>
</center>

<h4 id="wrapping-model-circuit-in-tf-quantum-model">Wrapping Model-Circuit in TF-Quantum Model</h4>

<p>To bring everything we’ve built together, TensorFlow Quantum’s layers plug directly into an ordinary Keras Sequential model. We’ll start with an input layer that takes the encoded circuits from earlier (which TensorFlow Quantum represents as <code class="language-plaintext highlighter-rouge">tf.string</code> tensors), and feed them into the quantum circuit. Since the circuit’s parameters are what we want the model to learn, we’ll wrap it in a <code class="language-plaintext highlighter-rouge">tfq.layers.PQC</code> (parameterized quantum circuit) layer, which returns the expectation value of the readout qubit.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">string</span><span class="p">),</span>
    <span class="n">tfq</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">PQC</span><span class="p">(</span><span class="n">model_circuit</span><span class="p">,</span> <span class="n">model_readout</span><span class="p">),</span>
<span class="p">])</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">PQC</code> layer returns values in the range $[-1, 1]$, which makes the hinge loss a natural fit. The hinge loss, however, expects the target labels to be $-1$ and $1$ rather than $0$ and $1$, so we encode them as follows:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_train_hinge</span> <span class="o">=</span> <span class="mf">2.0</span><span class="o">*</span><span class="n">y_train_nocon</span><span class="o">-</span><span class="mf">1.0</span>
<span class="n">y_test_hinge</span> <span class="o">=</span> <span class="mf">2.0</span><span class="o">*</span><span class="n">y_test</span><span class="o">-</span><span class="mf">1.0</span>
</code></pre></div></div>

<p>Alternatively, we could shift the model’s output range to $[0, 1]$, treat it as the probability the model assigns to class <code class="language-plaintext highlighter-rouge">3</code>, and train with the usual <code class="language-plaintext highlighter-rouge">tf.losses.BinaryCrossentropy</code> loss function.</p>
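
<p>That alternative only needs the affine map $p = (e + 1) / 2$ from an expectation value $e \in [-1, 1]$ to a probability $p \in [0, 1]$. Below is a minimal NumPy sketch of the map and of binary cross-entropy on the resulting probabilities. Both helpers are hypothetical. In Keras, the shift could be done with, for example, a <code class="language-plaintext highlighter-rouge">tf.keras.layers.Lambda(lambda x: (x + 1.0) / 2.0)</code> layer after the <code class="language-plaintext highlighter-rouge">PQC</code> layer, trained on the original $0/1$ labels.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def expectation_to_probability(expectation):
    # Affine map from [-1, 1] to [0, 1]: -1 -> 0, 0 -> 0.5, 1 -> 1
    return (np.asarray(expectation, dtype=float) + 1.0) / 2.0


def binary_crossentropy(y_true, prob, eps=1e-7):
    # Mean binary cross-entropy on probabilities (not logits); eps keeps log() finite
    prob = np.clip(np.asarray(prob, dtype=float), eps, 1.0 - eps)
    y_true = np.asarray(y_true, dtype=float)
    return float(np.mean(-(y_true * np.log(prob) + (1.0 - y_true) * np.log(1.0 - prob))))


expectations = np.array([-1.0, 0.0, 0.6])
probs = expectation_to_probability(expectations)
print(probs)
print(binary_crossentropy([0, 0, 1], probs))
</code></pre></div></div>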

<p>We then define a hinge accuracy metric that takes targets in $\{-1, 1\}$ and counts a prediction as correct when it falls on the same side of zero as its target.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">hinge_accuracy</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
    <span class="n">y_true</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mf">0.0</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mf">0.0</span>
    <span class="n">result</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span> <span class="o">==</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">result</span><span class="p">)</span>
</code></pre></div></div>
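
<p>To see how the loss and this metric behave on a few made-up predictions, here is a NumPy sketch. The hinge formula $\text{mean}(\max(0, 1 - y\,\hat{y}))$ is the one documented for <code class="language-plaintext highlighter-rouge">tf.keras.losses.Hinge</code>. The helper names are hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def hinge_loss(y_true, y_pred):
    # mean(max(1 - y_true * y_pred, 0)) with y_true in {-1, 1}
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(1.0 - y_true * y_pred, 0.0)))


def hinge_accuracy_np(y_true, y_pred):
    # Correct when prediction and label fall on the same side of zero,
    # exactly like hinge_accuracy above
    return float(np.mean((np.asarray(y_true) > 0) == (np.asarray(y_pred) > 0)))


y_true = [-1.0, 1.0, 1.0, -1.0]
y_pred = [-0.2, 0.3, -0.1, -0.5]   # PQC-style outputs in [-1, 1]
print(hinge_loss(y_true, y_pred))         # about 0.775
print(hinge_accuracy_np(y_true, y_pred))  # 0.75
</code></pre></div></div>

<p>Because the <code class="language-plaintext highlighter-rouge">PQC</code> outputs never leave $[-1, 1]$, the $\max$ never clips, so the hinge loss here is exactly $1 - \text{mean}(y\,\hat{y})$. The losses of roughly 0.97 in the training log below therefore mean the average signed margin is still only around 0.03, even though the signs are already right about two-thirds of the time.</p>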

<p>Lastly, we’ll do the usual <code class="language-plaintext highlighter-rouge">model.compile()</code>, passing it our loss function, optimizer, and the metrics to be recorded.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span>
    <span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">Hinge</span><span class="p">(),</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">Adam</span><span class="p">(),</span>
    <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">hinge_accuracy</span><span class="p">])</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">summary</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
pqc (PQC)                    (None, 1)                 32
=================================================================
Total params: 32
Trainable params: 32
Non-trainable params: 0
_________________________________________________________________
None
</code></pre></div></div>

<h4 id="training-quantum-neural-network">Training Quantum Neural Network</h4>

<p>With everything in place, we can start training our model. Conveniently, TensorFlow Quantum provides a default <code class="language-plaintext highlighter-rouge">Differentiator</code> that computes gradients through the quantum circuit, so we don’t need to handle backpropagation manually. It is possible to supply our own <code class="language-plaintext highlighter-rouge">Differentiator</code> instead, but we won’t be doing that here.</p>
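
<p>To give a feel for what a differentiator computes, here is a toy NumPy sketch of the parameter-shift rule. This is one of the gradient techniques TensorFlow Quantum offers in <code class="language-plaintext highlighter-rouge">tfq.differentiators</code> (e.g. <code class="language-plaintext highlighter-rouge">ParameterShift</code>). It is not TensorFlow Quantum’s implementation, and I’m not assuming it is the default that <code class="language-plaintext highlighter-rouge">PQC</code> uses. The setup is one qubit, one gate $X^t$, and a $Z$ measurement, so $\langle Z \rangle = \cos(\pi t)$.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

Z = np.array([[1, 0], [0, -1]], dtype=complex)
KET0 = np.array([1, 0], dtype=complex)


def x_pow(t):
    # X**t in Cirq equals Rx(pi * t) up to a global phase, which does not affect expectations
    half_angle = np.pi * t / 2
    c, s = np.cos(half_angle), np.sin(half_angle)
    return np.array([[c, -1j * s], [-1j * s, c]])


def expectation(t):
    # Z expectation after applying X**t to |0>; analytically cos(pi * t)
    state = x_pow(t) @ KET0
    return float(np.real(np.vdot(state, Z @ state)))


def parameter_shift_grad(f, t):
    # Shift rule for an exponent t (rotation angle pi * t):
    # df/dt = (pi / 2) * (f(t + 1/2) - f(t - 1/2))
    return (np.pi / 2) * (f(t + 0.5) - f(t - 0.5))


t = 0.3
shift_grad = parameter_shift_grad(expectation, t)
exact_grad = -np.pi * np.sin(np.pi * t)
finite_diff = (expectation(t + 1e-6) - expectation(t - 1e-6)) / 2e-6
print(shift_grad, exact_grad, finite_diff)
</code></pre></div></div>

<p>Unlike the finite difference, the shift rule is exact here and needs only two extra circuit evaluations for this parameter. That is why shift-style rules are attractive when every evaluation means running a circuit.</p>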

<p>We’ll first set the number of epochs, the batch size, and the number of examples to train on. Since there are quite a lot of training images (11,520 after removing contradictory examples), we’ll use only a subset of them to shorten training while still seeing the model learn.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span>

<span class="n">NUM_EXAMPLES</span> <span class="o">=</span> <span class="mi">500</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_train_tfcirc_sub</span> <span class="o">=</span> <span class="n">x_train_tfcirc</span><span class="p">[:</span><span class="n">NUM_EXAMPLES</span><span class="p">]</span>
<span class="n">y_train_hinge_sub</span> <span class="o">=</span> <span class="n">y_train_hinge</span><span class="p">[:</span><span class="n">NUM_EXAMPLES</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">qnn_history</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span>
      <span class="n">x_train_tfcirc_sub</span><span class="p">,</span> <span class="n">y_train_hinge_sub</span><span class="p">,</span>
      <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span>
      <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span>
      <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
      <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test_tfcirc</span><span class="p">,</span> <span class="n">y_test_hinge</span><span class="p">))</span>

<span class="n">qnn_results</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test_tfcirc</span><span class="p">,</span> <span class="n">y_test_hinge</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Train on 500 samples, validate on 1968 samples
Epoch 1/3
500/500 [==============================] - 301s 602ms/sample - loss: 0.9929 - hinge_accuracy: 0.6199 - val_loss: 0.9887 - val_hinge_accuracy: 0.6739
Epoch 2/3
500/500 [==============================] - 300s 600ms/sample - loss: 0.9849 - hinge_accuracy: 0.6777 - val_loss: 0.9808 - val_hinge_accuracy: 0.6774
Epoch 3/3
500/500 [==============================] - 301s 602ms/sample - loss: 0.9756 - hinge_accuracy: 0.6746 - val_loss: 0.9687 - val_hinge_accuracy: 0.6809
1968/1968 [==============================] - 34s 17ms/sample - loss: 0.9687 - hinge_accuracy: 0.6809
</code></pre></div></div>

<p>Note that the training accuracy is averaged over each epoch, while the validation accuracy is evaluated at the end of each epoch. Here, our model reached a validation hinge accuracy of about 0.68. As with other quantum or hybrid neural networks, this value varies from run to run. The highest accuracy I have obtained with the exact same subset and circuit was 0.80.</p>

<h3 id="classical-neural-network">Classical Neural Network</h3>

<p>A classical neural network will almost certainly outperform this QNN, even a very simple classical <a href="https://en.wikipedia.org/wiki/Convolutional_neural_network">Convolutional Neural Network</a> (CNN). The tutorial shows an example of a CNN based on <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">LeNet</a>, taken from a <a href="https://keras.io/examples/mnist_cnn/">Keras tutorial</a>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">create_classical_model</span><span class="p">():</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">()</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">activation</span><span class="o">=</span><span class="s">'relu'</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span><span class="mi">28</span><span class="p">,</span><span class="mi">1</span><span class="p">)))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">activation</span><span class="o">=</span><span class="s">'relu'</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.25</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Flatten</span><span class="p">())</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">'relu'</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">create_classical_model</span><span class="p">()</span>
<span class="n">model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">BinaryCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
              <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">Adam</span><span class="p">(),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s">'accuracy'</span><span class="p">])</span>

<span class="n">model</span><span class="p">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 24, 24, 64)        18496
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 64)        0
_________________________________________________________________
dropout (Dropout)            (None, 12, 12, 64)        0
_________________________________________________________________
flatten (Flatten)            (None, 9216)              0
_________________________________________________________________
dense (Dense)                (None, 128)               1179776
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 129
=================================================================
Total params: 1,198,721
Trainable params: 1,198,721
Non-trainable params: 0
_________________________________________________________________
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span>
          <span class="n">y_train</span><span class="p">,</span>
          <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
          <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
          <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
          <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span>

<span class="n">cnn_results</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Train on 12049 samples, validate on 1968 samples
12049/12049 [==============================] - 7s 557us/sample - loss: 0.0397 - accuracy: 0.9854 - val_loss: 0.0053 - val_accuracy: 0.9990
1968/1968 [==============================] - 0s 144us/sample - loss: 0.0053 - accuracy: 0.9990
</code></pre></div></div>

<p>After just a single epoch, the classical CNN reaches 0.999 validation accuracy. Although it is a simple CNN, it is fed the original 28x28-pixel images and has about 1.2M parameters, so comparing it directly with our QNN is not really fair.</p>

<p>To level the playing field, we’ll create a 37-parameter classical neural network that, like the QNN, works on images downscaled to 4x4 pixels.</p>
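<p>Both parameter counts can be checked by hand: a dense layer has one weight per input-output pair plus one bias per unit, and a convolution has one kernel per filter plus one bias per filter. The plain-Python sketch below (hypothetical helper names, no TensorFlow needed) reproduces the totals. The 3x3 kernel size is an assumption inferred from the 28 → 26 → 24 output shapes in the CNN summary above, not read from the model code.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def dense_params(n_in, n_out):
    # One weight per input-output pair plus one bias per output unit
    return n_in * n_out + n_out


def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels kernel per filter, plus one bias each
    return kernel_h * kernel_w * in_channels * filters + filters


# Full CNN, assuming 3x3 kernels (inferred from the output shapes in the summary)
cnn_total = (
    conv2d_params(3, 3, 1, 32)          # conv2d
    + conv2d_params(3, 3, 32, 64)       # conv2d_1
    + dense_params(12 * 12 * 64, 128)   # dense, fed the 9216 flattened values
    + dense_params(128, 1)              # dense_1
)

# "Fair" model: a 4x4 image flattened to 16 inputs, then Dense(2) and Dense(1)
fair_total = dense_params(4 * 4, 2) + dense_params(2, 1)

print(cnn_total)   # 1198721
print(fair_total)  # 37
</code></pre></div></div>

<p>The gap comes mostly from the single dense layer after flattening, which alone accounts for 1,179,776 of the CNN’s parameters.</p>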

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">create_fair_classical_model</span><span class="p">():</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">()</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Flatten</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">,</span><span class="mi">1</span><span class="p">)))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">'relu'</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">create_fair_classical_model</span><span class="p">()</span>
<span class="n">model</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">losses</span><span class="p">.</span><span class="n">BinaryCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
              <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">Adam</span><span class="p">(),</span>
              <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s">'accuracy'</span><span class="p">])</span>

<span class="n">model</span><span class="p">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 16)                0
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 34
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 3
=================================================================
Total params: 37
Trainable params: 37
Non-trainable params: 0
_________________________________________________________________
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train_bin</span><span class="p">,</span>
          <span class="n">y_train_nocon</span><span class="p">,</span>
          <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
          <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
          <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
          <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test_bin</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span>

<span class="n">fair_nn_results</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test_bin</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Train on 11520 samples, validate on 1968 samples
Epoch 1/20
11520/11520 - 1s - loss: 0.7959 - accuracy: 0.4551 - val_loss: 0.7675 - val_accuracy: 0.4853
Epoch 2/20
11520/11520 - 0s - loss: 0.7290 - accuracy: 0.5030 - val_loss: 0.7130 - val_accuracy: 0.4868
Epoch 3/20
11520/11520 - 0s - loss: 0.6995 - accuracy: 0.5031 - val_loss: 0.6968 - val_accuracy: 0.4868
Epoch 4/20
11520/11520 - 0s - loss: 0.6918 - accuracy: 0.5034 - val_loss: 0.6925 - val_accuracy: 0.4878
Epoch 5/20
11520/11520 - 0s - loss: 0.6847 - accuracy: 0.5103 - val_loss: 0.6793 - val_accuracy: 0.4924
Epoch 6/20
11520/11520 - 0s - loss: 0.6553 - accuracy: 0.5845 - val_loss: 0.6425 - val_accuracy: 0.6316
Epoch 7/20
11520/11520 - 0s - loss: 0.5934 - accuracy: 0.6980 - val_loss: 0.5676 - val_accuracy: 0.7429
Epoch 8/20
11520/11520 - 0s - loss: 0.5298 - accuracy: 0.8106 - val_loss: 0.5105 - val_accuracy: 0.8216
Epoch 9/20
11520/11520 - 0s - loss: 0.4782 - accuracy: 0.8536 - val_loss: 0.4658 - val_accuracy: 0.8323
Epoch 10/20
11520/11520 - 0s - loss: 0.4386 - accuracy: 0.8595 - val_loss: 0.4318 - val_accuracy: 0.8333
Epoch 11/20
11520/11520 - 0s - loss: 0.4068 - accuracy: 0.8617 - val_loss: 0.4039 - val_accuracy: 0.8338
Epoch 12/20
11520/11520 - 0s - loss: 0.3811 - accuracy: 0.8635 - val_loss: 0.3813 - val_accuracy: 0.8338
Epoch 13/20
11520/11520 - 0s - loss: 0.3599 - accuracy: 0.8641 - val_loss: 0.3624 - val_accuracy: 0.8328
Epoch 14/20
11520/11520 - 0s - loss: 0.3421 - accuracy: 0.8648 - val_loss: 0.3465 - val_accuracy: 0.8328
Epoch 15/20
11520/11520 - 0s - loss: 0.3270 - accuracy: 0.8720 - val_loss: 0.3329 - val_accuracy: 0.8714
Epoch 16/20
11520/11520 - 0s - loss: 0.3140 - accuracy: 0.8874 - val_loss: 0.3210 - val_accuracy: 0.8714
Epoch 17/20
11520/11520 - 0s - loss: 0.3028 - accuracy: 0.8876 - val_loss: 0.3109 - val_accuracy: 0.8714
Epoch 18/20
11520/11520 - 0s - loss: 0.2931 - accuracy: 0.8876 - val_loss: 0.3019 - val_accuracy: 0.8714
Epoch 19/20
11520/11520 - 0s - loss: 0.2846 - accuracy: 0.8876 - val_loss: 0.2941 - val_accuracy: 0.8714
Epoch 20/20
11520/11520 - 0s - loss: 0.2771 - accuracy: 0.8873 - val_loss: 0.2872 - val_accuracy: 0.8714
1968/1968 [==============================] - 0s 78us/sample - loss: 0.2872 - accuracy: 0.8714
</code></pre></div></div>

<p>Unsurprisingly, this model performed better than the QNN, and arguably more stably. The data is entirely classical, so it is reasonable that a classical neural network would outperform a quantum one.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">qnn_accuracy</span> <span class="o">=</span> <span class="n">qnn_results</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">cnn_accuracy</span> <span class="o">=</span> <span class="n">cnn_results</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">fair_nn_accuracy</span> <span class="o">=</span> <span class="n">fair_nn_results</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

<span class="n">sns</span><span class="p">.</span><span class="n">barplot</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="p">[</span><span class="s">"Quantum"</span><span class="p">,</span> <span class="s">"Classical, full"</span><span class="p">,</span> <span class="s">"Classical, fair"</span><span class="p">],</span>
            <span class="n">y</span><span class="o">=</span><span class="p">[</span><span class="n">qnn_accuracy</span><span class="p">,</span> <span class="n">cnn_accuracy</span><span class="p">,</span> <span class="n">fair_nn_accuracy</span><span class="p">])</span>
</code></pre></div></div>


<center>
<img src="/images/2020/07/mnist-qnn/output_55_1.png" style="zoom: 70%;" />
</center>

<h3 id="experiments">Experiments</h3>

<p>After learning how to create a QNN from the tutorial, I decided to experiment with the number of parameters in the model. Instead of using one layer of Ising $(XX)$ coupling gates and one layer of Ising $(ZZ)$ coupling gates, I used two layers of each. Since every layer adds one parameterized gate per data qubit (16 in total), this adds 32 parameters to the model, for a total of 64.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">create_quantum_model</span><span class="p">():</span>
    <span class="n">data_qubits</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">.</span><span class="n">rect</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">readout</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">GridQubit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">circuit</span> <span class="o">=</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Circuit</span><span class="p">()</span>

    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">X</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>
    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">H</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>

    <span class="n">builder</span> <span class="o">=</span> <span class="n">CircuitLayerBuilder</span><span class="p">(</span>
        <span class="n">data_qubits</span> <span class="o">=</span> <span class="n">data_qubits</span><span class="p">,</span>
        <span class="n">readout</span><span class="o">=</span><span class="n">readout</span><span class="p">)</span>

    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">XX</span><span class="p">,</span> <span class="s">"xx1"</span><span class="p">)</span>
    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">XX</span><span class="p">,</span> <span class="s">"xx2"</span><span class="p">)</span>
    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">ZZ</span><span class="p">,</span> <span class="s">"zz1"</span><span class="p">)</span>
    <span class="n">builder</span><span class="p">.</span><span class="n">add_layer</span><span class="p">(</span><span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">ZZ</span><span class="p">,</span> <span class="s">"zz2"</span><span class="p">)</span>

    <span class="n">circuit</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cirq</span><span class="p">.</span><span class="n">H</span><span class="p">(</span><span class="n">readout</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">circuit</span><span class="p">,</span> <span class="n">cirq</span><span class="p">.</span><span class="n">Z</span><span class="p">(</span><span class="n">readout</span><span class="p">)</span>
</code></pre></div></div>
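<p>To see where 64 comes from, we can count the symbols the circuit would contain. The sketch below assumes <code>CircuitLayerBuilder</code> behaves as in the TensorFlow Quantum MNIST tutorial, where <code>add_layer</code> appends one gate per data qubit (coupling it to the readout qubit), each raised to its own <code>sympy</code> symbol named after the prefix and the qubit index. The helper is a hypothetical stand-in that only records those names, so it runs without Cirq.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical stand-in for CircuitLayerBuilder.add_layer that only records the
# symbol names it would create instead of appending gates to a cirq.Circuit
def layer_symbol_names(prefix, n_data_qubits):
    # One parameterized two-qubit gate per data qubit, each with its own symbol
    return [prefix + "-" + str(i) for i in range(n_data_qubits)]


N_DATA_QUBITS = 4 * 4  # cirq.GridQubit.rect(4, 4)


def count_params(prefixes):
    # Every distinct symbol becomes one trainable weight of the PQC layer
    names = set()
    for prefix in prefixes:
        names.update(layer_symbol_names(prefix, N_DATA_QUBITS))
    return len(names)


print(count_params(["xx1", "zz1"]))                # 32, one XX and one ZZ layer
print(count_params(["xx1", "xx2", "zz1", "zz2"]))  # 64, two of each
</code></pre></div></div>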

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">summary</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
pqc (PQC)                    (None, 1)                 64
=================================================================
Total params: 64
Trainable params: 64
Non-trainable params: 0
_________________________________________________________________
None
</code></pre></div></div>

<p>I also trained on 1,000 sample images instead of 500, just for fun. Everything else is kept identical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span>

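<span class="n">NUM_EXAMPLES</span> <span class="o">=</span> <span class="mi">1000</span>

<span class="c1"># Re-slice the training subset so the new NUM_EXAMPLES takes effect
</span><span class="n">x_train_tfcirc_sub</span> <span class="o">=</span> <span class="n">x_train_tfcirc</span><span class="p">[:</span><span class="n">NUM_EXAMPLES</span><span class="p">]</span>
<span class="n">y_train_hinge_sub</span> <span class="o">=</span> <span class="n">y_train_hinge</span><span class="p">[:</span><span class="n">NUM_EXAMPLES</span><span class="p">]</span>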
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">qnn_history</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span>
      <span class="n">x_train_tfcirc_sub</span><span class="p">,</span> <span class="n">y_train_hinge_sub</span><span class="p">,</span>
      <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span>
      <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span>
      <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
      <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_test_tfcirc</span><span class="p">,</span> <span class="n">y_test_hinge</span><span class="p">))</span>

<span class="n">qnn_results</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test_tfcirc</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Train on 1000 samples, validate on 1968 samples
Epoch 1/3
1000/1000 [==============================] - 1693s 2s/sample - loss: 0.9964 - hinge_accuracy: 0.6748 - val_loss: 0.9851 - val_hinge_accuracy: 0.7999
Epoch 2/3
1000/1000 [==============================] - 1704s 2s/sample - loss: 0.9271 - hinge_accuracy: 0.8066 - val_loss: 0.8194 - val_hinge_accuracy: 0.7964
Epoch 3/3
1000/1000 [==============================] - 1593s 2s/sample - loss: 0.6629 - hinge_accuracy: 0.7988 - val_loss: 0.5120 - val_hinge_accuracy: 0.7964
1968/1968 [==============================] - 50s 25ms/sample - loss: 0.5120 - hinge_accuracy: 0.7964
</code></pre></div></div>

<p>This time, the validation hinge accuracy is 0.7964 and the validation loss is much lower (0.5120), which is better than our previous 32-parameter model. Note again that these values vary from trial to trial, so a single run does not fully represent the model’s performance.</p>
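<p>Hinge accuracy counts a prediction as correct when it falls on the same side of zero as its ±1 label. The plain-Python sketch below mirrors the <code>hinge_accuracy</code> metric from the TensorFlow Quantum MNIST tutorial (assumed here to be the one used), alongside the standard hinge loss, on made-up toy values. It also shows why the loss can keep falling while the accuracy stays flat, as in the log above: the loss still rewards larger margins even when no prediction changes sign.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def hinge_loss(y_true, y_pred):
    # Mean of max(0, 1 - y * y_hat), with labels y in {-1, +1}
    return sum(max(0.0, 1.0 - t * p) for t, p in zip(y_true, y_pred)) / len(y_true)


def hinge_accuracy(y_true, y_pred):
    # Correct when the prediction lies on the same side of zero as the label
    correct = sum((t > 0.0) == (p > 0.0) for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Made-up labels and PQC outputs (readout expectations lie in [-1, 1])
labels = [1.0, -1.0, 1.0, -1.0]
outputs = [0.2, -0.1, -0.3, -0.9]
# Same signs as `outputs`, but with larger margins on the correct predictions
confident = [0.9, -0.8, -0.3, -0.9]

print(hinge_accuracy(labels, outputs), hinge_loss(labels, outputs))      # accuracy 0.75, loss about 0.775
print(hinge_accuracy(labels, confident), hinge_loss(labels, confident))  # accuracy 0.75, loss about 0.425
</code></pre></div></div>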

<h2 id="closing-remarks">Closing Remarks</h2>

<h3 id="issues-with-quantum-neural-network">Issues with Quantum Neural Network</h3>

<p>As discussed in the previous post, QNNs, and Quantum Computers in general, still have open issues. There is no analytical way to get the gradients of the quantum layers yet, and the circuit’s gradient sometimes vanishes as the model learns. There is plenty of room for improvement, as well as for research into how Quantum Neural Networks might overcome the limits of classical neural networks.</p>

<h3 id="conclusion">Conclusion</h3>

<p>It’s been quite a ride learning how a Quantum Neural Network is implemented. It is very different from implementing a classical neural network, and there are many factors to consider, since the capabilities of Quantum Computers and their simulators are still limited. Still, it was a mind-blowing experience to glimpse the future potential of Quantum Computers and what they can offer to the Machine Learning domain.</p>

<h3 id="credits">Credits</h3>

<p>Portions of this page are modifications based on work created and <a href="https://developers.google.com/terms/site-policies">shared by Google</a> and used according to terms described in the <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons 4.0 Attribution License</a>.</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Quantum Machine Learning" /><category term="Quantum Computation" /><summary type="html"><![CDATA[Tensorflow is one of the most widely used deep learning frameworks today, bundled with many features for end-to-end deep learning processes. Recently, they announced a new library built on top of Tensorflow, called Tensorflow Quantum. Tensorflow Quantum integrates with Cirq, which provides quantum computing algorithms, and the two work well together for tasks involving Quantum Machine Learning.]]></summary></entry><entry><title type="html">MNIST Classification with Hybrid Quantum-Classical Neural Network</title><link href="https://wilsonwongso.dev/posts/2020/07/mnist-hybrid-qnn/" rel="alternate" type="text/html" title="MNIST Classification with Hybrid Quantum-Classical Neural Network" /><published>2020-07-13T00:00:00+10:00</published><updated>2020-07-13T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/07/mnist-hybrid-qnn</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/07/mnist-hybrid-qnn/"><![CDATA[<p><a href="https://qiskit.org/">Qiskit</a> is IBM’s open-source framework for quantum computing, which gives users access to both simulators and real Quantum Computers. Today’s Quantum Computers are still in the Noisy Intermediate-Scale Quantum (NISQ) era and are very sensitive to any form of interference. Unlike real Quantum Computers, <a href="https://www.ibm.com/quantum-computing/technology/simulator/">simulators provided by Qiskit</a> aren’t noisy, which makes them great for prototyping.</p>

<h3 id="hybrid-quantum-classical-neural-network">Hybrid Quantum-Classical Neural Network</h3>

<p>Together, Qiskit and PyTorch provide a way to connect classical neural networks with quantum circuits, creating a hybrid quantum-classical NN. The Qiskit textbook provides a <a href="https://qiskit.org/textbook/ch-machine-learning/machine-learning-qiskit-pytorch.html">tutorial</a> on this, which is the basis of the code shown in this post.</p>

<h4 id="forward-pass">Forward Pass</h4>

<p>The following diagram shows how a hybrid NN performs its forward pass:</p>

<center>
<img src="/images/neuralnetworkQC.png" style="zoom: 70%;" /><br />
<figcaption><i>Hybrid Quantum-Classical Neural Network | Qiskit Textbook</i></figcaption>
</center>

<p>As shown above, the neural network starts with its usual classical layers, has a quantum “layer” in the middle, and ends with classical layers again. The neural network learns to optimize the parameters of this quantum layer.</p>

<p>The choice of classical layers is arbitrary; however, the output of the first classical layers must conform to the input of the quantum layer (as we’ll see later in code). Similarly, the output of the quantum layer must match the input of the classical layer that follows.</p>
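<p>This shape requirement can be illustrated without any framework. In the plain-Python sketch below (weights, feature values, and function names are all hypothetical), the 1-qubit circuit takes exactly one angle, so the classical layer before it must output a single value; the circuit in turn returns a single number, so whatever follows must accept one input. The quantum layer is replaced by its exact noise-free output for a Hadamard followed by $RY(\theta)$, and the final step turning $p$ into two class scores $[p, 1 - p]$ is an assumption for a two-class problem.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

# Hypothetical weights for a tiny classical layer that maps 3 features to 1 value
WEIGHTS = [0.5, -1.0, 0.25]
BIAS = 0.1


def classical_layer(features):
    # Linear layer with a single output, which becomes the rotation angle theta
    return sum(w * x for w, x in zip(WEIGHTS, features)) + BIAS


def quantum_layer(theta):
    # Exact (noise-free) average of the measured bit after H then RY(theta)
    # on one qubit: (1 + sin(theta)) / 2, a single number in [0, 1]
    return (1.0 + math.sin(theta)) / 2.0


def forward(features):
    theta = classical_layer(features)  # 1 classical output -> 1 circuit parameter
    p = quantum_layer(theta)           # 1 circuit output -> 1 value for the next step
    return [p, 1.0 - p]                # two class scores that sum to 1


print(forward([1.0, 0.2, 0.4]))
</code></pre></div></div>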

<h4 id="backward-pass">Backward Pass</h4>

<p>This raises a question for backpropagation. Gradient descent, a critical step in optimizing the model, requires the derivative of the quantum layer. To tackle this problem, we’ll use the <a href="https://arxiv.org/pdf/1905.13311.pdf">parameter shift rule</a> to find its gradient, which is calculated as follows:</p>

<center>
<img src="/images/quantumgradient.png" style="zoom: 70%;" /><br />
<figcaption><i>Gradient of Quantum Layer | Qiskit Textbook</i></figcaption>
</center>

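<p>The parameter shift rule resembles finite differences: the parameter is shifted in both directions, and the resulting change in the output is measured. Unlike finite differences, however, the shift is macroscopic rather than infinitesimally small, and with the appropriate shift (such as $\pi/2$ for rotation gates like $RY$) the rule yields the exact gradient rather than an approximation. We won’t go into further detail here.</p>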
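<p>In its standard form for Pauli rotation gates such as $RY$, the rule reads $\frac{\partial f}{\partial \theta} = \frac{1}{2}\left[f(\theta + \frac{\pi}{2}) - f(\theta - \frac{\pi}{2})\right]$. The plain-Python sketch below (hypothetical function names) checks it on the exact, noise-free output of a single-qubit Hadamard followed by $RY(\theta)$, where averaging the measured bit gives $(1 + \sin\theta)/2$, and compares it with a central finite difference. On a shot-based simulator or real hardware, each evaluation of $f$ would itself be an estimate from measurement counts.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def expectation(theta):
    # Exact value for H then RY(theta) on |0>, averaging the measured bit:
    # P(1) = (1 + sin(theta)) / 2, whose true derivative is cos(theta) / 2
    return (1.0 + math.sin(theta)) / 2.0


def parameter_shift_grad(f, theta, shift=math.pi / 2):
    # Two evaluations at macroscopically shifted angles; exact for RY-type gates
    return (f(theta + shift) - f(theta - shift)) / 2.0


def finite_difference_grad(f, theta, h=1e-6):
    # Central finite difference with a tiny step; only an approximation
    return (f(theta + h) - f(theta - h)) / (2.0 * h)


theta = 0.3
print(parameter_shift_grad(expectation, theta))
print(finite_difference_grad(expectation, theta))
print(math.cos(theta) / 2.0)  # analytic derivative for comparison
</code></pre></div></div>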

<h3 id="mnist-classification">MNIST Classification</h3>

<p><a href="http://yann.lecun.com/exdb/mnist/">MNIST</a> is a go-to dataset for image classification because it is beginner-friendly. We’ll likewise use MNIST to test how our hybrid NN performs, but in this case we’ll classify only 2 digits instead of the usual 10.</p>

<h2 id="code-classifying-0s-and-1s">Code: Classifying 0s and 1s</h2>

<h3 id="quantum-circuit">Quantum Circuit</h3>

<p>As mentioned above, we’ll create a quantum circuit whose parameter the neural network tweaks as it learns. The example given in the textbook is a very simple 1-qubit circuit with two gates: a Hadamard and an $RY$ gate. An $RY$ rotation has a parameter $\theta$, which is precisely the parameter to be optimized.</p>

<center>
<img src="/images/1qubitcirc.png" style="zoom: 70%;" /><br />
<figcaption><i>Quantum Circuit | Qiskit Textbook</i></figcaption>
</center>

<p>After passing through the two gates, the qubit is measured, and the result of this measurement is used as the final output of the neural network. A 1-qubit measurement has only two possible outcomes, which in our case correspond to the two classes an image can belong to. To measure the $z$-basis output, we’ll calculate the expectation value of $\sigma_z$ the same way we would calculate an expected value in statistics:</p>

\[\langle \sigma_z \rangle = \sum_{i} z_i \cdot p(z_i)\]

<p>Later, we’ll tell the circuit how many shots (repeated runs) to make; the probabilities $p(z_i)$ are estimated from the measurement counts over those shots.</p>
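<p>In the <code>run</code> method of the <code>QuantumCircuit</code> class implemented next, the bitstring keys returned by <code>get_counts</code> (<code>"0"</code> and <code>"1"</code>) are converted to the numbers 0 and 1, so the returned value is the estimated probability of measuring 1; on the usual ±1 eigenvalue scale of $\sigma_z$ this corresponds to $P(0) - P(1) = 1 - 2P(1)$. The plain-Python sketch below repeats that arithmetic on a counts dictionary. The sampler is a hypothetical stand-in for the simulator that draws outcomes for a Hadamard followed by $RY(\theta)$, where $P(1) = (1 + \sin\theta)/2$; for $\theta = \pi$ the exact value is 0.5, so a 100-shot estimate scatters around it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import random


def counts_to_expectation(counts, shots):
    # Same arithmetic as QuantumCircuit.run: each bitstring key ("0" or "1")
    # is read as a number and weighted by its empirical probability count / shots
    return sum(float(state) * count / shots for state, count in counts.items())


def sample_counts(theta, shots, seed=0):
    # Hypothetical stand-in for the qasm simulator: draws `shots` measurement
    # outcomes of H then RY(theta) on |0>, where P(1) = (1 + sin(theta)) / 2
    rng = random.Random(seed)
    p_one = (1.0 + math.sin(theta)) / 2.0
    ones = sum(1 for _ in range(shots) if p_one > rng.random())
    return {"0": shots - ones, "1": ones}


counts = sample_counts(math.pi, shots=100)
p_one = counts_to_expectation(counts, shots=100)
print(counts, p_one)      # p_one scatters around the exact value 0.5
print(1.0 - 2.0 * p_one)  # the same result on the +1/-1 (Pauli-Z) scale: P(0) - P(1)
</code></pre></div></div>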

<p>Let’s implement the circuit in Qiskit!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">QuantumCircuit</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_qubits</span><span class="p">,</span> <span class="n">backend</span><span class="p">,</span> <span class="n">shots</span><span class="p">):</span>
        <span class="c1"># --- Circuit definition ---
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span> <span class="o">=</span> <span class="n">qiskit</span><span class="p">.</span><span class="n">QuantumCircuit</span><span class="p">(</span><span class="n">n_qubits</span><span class="p">)</span>

        <span class="n">all_qubits</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_qubits</span><span class="p">)]</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">theta</span> <span class="o">=</span> <span class="n">qiskit</span><span class="p">.</span><span class="n">circuit</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="s">'theta'</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">.</span><span class="n">h</span><span class="p">(</span><span class="n">all_qubits</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">.</span><span class="n">barrier</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">.</span><span class="n">ry</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">theta</span><span class="p">,</span> <span class="n">all_qubits</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">.</span><span class="n">measure_all</span><span class="p">()</span>
        <span class="c1"># ---------------------------
</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">backend</span> <span class="o">=</span> <span class="n">backend</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">shots</span> <span class="o">=</span> <span class="n">shots</span>

    <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">thetas</span><span class="p">):</span>
        <span class="n">job</span> <span class="o">=</span> <span class="n">qiskit</span><span class="p">.</span><span class="n">execute</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">,</span>
                             <span class="bp">self</span><span class="p">.</span><span class="n">backend</span><span class="p">,</span>
                             <span class="n">shots</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">shots</span><span class="p">,</span>
                             <span class="n">parameter_binds</span> <span class="o">=</span> <span class="p">[{</span><span class="bp">self</span><span class="p">.</span><span class="n">theta</span><span class="p">:</span> <span class="n">theta</span><span class="p">}</span> <span class="k">for</span> <span class="n">theta</span> <span class="ow">in</span> <span class="n">thetas</span><span class="p">])</span>
        <span class="n">result</span> <span class="o">=</span> <span class="n">job</span><span class="p">.</span><span class="n">result</span><span class="p">().</span><span class="n">get_counts</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">_circuit</span><span class="p">)</span>

        <span class="n">counts</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">result</span><span class="p">.</span><span class="n">values</span><span class="p">()))</span>
        <span class="n">states</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">result</span><span class="p">.</span><span class="n">keys</span><span class="p">())).</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>

        <span class="c1"># Compute probabilities for each state
</span>        <span class="n">probabilities</span> <span class="o">=</span> <span class="n">counts</span> <span class="o">/</span> <span class="bp">self</span><span class="p">.</span><span class="n">shots</span>
        <span class="c1"># Get state expectation
</span>        <span class="n">expectation</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">states</span> <span class="o">*</span> <span class="n">probabilities</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">expectation</span><span class="p">])</span>
</code></pre></div></div>
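
<p>The last few lines of <code class="language-plaintext highlighter-rouge">run</code> turn counts into an expectation value. Here is a small standalone NumPy sketch of that same arithmetic. The counts dictionary is made up for illustration. It assumes the bitstring-to-count form that <code class="language-plaintext highlighter-rouge">get_counts</code> returns for a single measured qubit, and the helper name is hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def expectation_from_counts(counts, shots):
    """Mirror of the arithmetic in QuantumCircuit.run for a 1-qubit circuit.

    counts maps measured bitstrings ('0' or '1') to how often they occurred.
    """
    states = np.array(list(counts.keys())).astype(float)     # '0' becomes 0.0, '1' becomes 1.0
    probabilities = np.array(list(counts.values())) / shots  # relative frequencies
    return np.sum(states * probabilities)                    # weighted average of 0 and 1

# Hypothetical measurement outcome from 100 shots
counts = {'0': 48, '1': 52}
print(expectation_from_counts(counts, shots=100))  # 0.52
</code></pre></div></div>

<p>Because the only possible states are 0 and 1, the expectation of a one-qubit circuit is simply the estimated probability of measuring 1.</p>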

<h3 id="testing-quantum-circuit">Testing Quantum Circuit</h3>

<p>Just for fun, the textbook also shows how to test the circuit on its own. We create a circuit with 1 qubit, pass in the simulator backend, use 100 shots, and set the rotation angle to $\pi$.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">simulator</span> <span class="o">=</span> <span class="n">qiskit</span><span class="p">.</span><span class="n">Aer</span><span class="p">.</span><span class="n">get_backend</span><span class="p">(</span><span class="s">'qasm_simulator'</span><span class="p">)</span>

<span class="n">circuit</span> <span class="o">=</span> <span class="n">QuantumCircuit</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">simulator</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">'Expected value for rotation pi: {}'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">circuit</span><span class="p">.</span><span class="n">run</span><span class="p">([</span><span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">])[</span><span class="mi">0</span><span class="p">]))</span>
<span class="n">circuit</span><span class="p">.</span><span class="n">_circuit</span><span class="p">.</span><span class="n">draw</span><span class="p">(</span><span class="n">output</span><span class="o">=</span><span class="s">'mpl'</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Expected value for rotation pi: 0.5
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_6_1.png" style="zoom: 70%;" />
</center>
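
<p>The 0.5 printed above can be checked by hand. Suppose the circuit follows the Qiskit textbook: a Hadamard gate, then $R_Y(\theta)$, then a measurement. This is an assumption about the circuit built earlier. Under it, the probability of measuring 1 is $(1 + \sin\theta)/2$, which is exactly 0.5 at $\theta = \pi$. The NumPy sketch below computes that value from plain state vectors. It also simulates a finite number of shots to show the sampling noise a simulator adds. All helper names are hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Assumption: the circuit is H followed by RY(theta), measured in the computational basis
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)

def ry(theta):
    """Matrix of the single-qubit RY(theta) rotation."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def exact_expectation(theta):
    """Probability of measuring 1 after applying H, then RY(theta), to the 0 state."""
    state = ry(theta) @ H @ np.array([1.0, 0.0])
    return abs(state[1]) ** 2

def sampled_expectation(theta, shots, rng):
    """Finite-shot estimate of the same probability, as a simulator would give."""
    p = np.clip(exact_expectation(theta), 0.0, 1.0)  # guard against float round-off
    return rng.binomial(shots, p) / shots

theta = np.pi
print(exact_expectation(theta))              # about 0.5
print((1 + np.sin(theta)) / 2)               # closed form, also about 0.5
rng = np.random.default_rng(0)
print(sampled_expectation(theta, 100, rng))  # lands near 0.5
</code></pre></div></div>

<p>The shot noise is why rerunning the Qiskit test can print values slightly different from 0.5. With 100 shots, the standard deviation of the estimate is $\sqrt{0.5 \cdot 0.5 / 100} = 0.05$.</p>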

<h3 id="quantum-classical-class">Quantum-Classical Class</h3>

<p>With the circuit defined, we can wrap it in a hybrid quantum-classical layer for PyTorch. We do this with a custom autograd <code class="language-plaintext highlighter-rouge">Function</code>. The forward pass simply runs the circuit, and the backward pass computes gradients with the parameter shift rule we discussed earlier.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">HybridFunction</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">quantum_circuit</span><span class="p">,</span> <span class="n">shift</span><span class="p">):</span>
        <span class="s">""" Forward pass computation """</span>
        <span class="n">ctx</span><span class="p">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">shift</span>
        <span class="n">ctx</span><span class="p">.</span><span class="n">quantum_circuit</span> <span class="o">=</span> <span class="n">quantum_circuit</span>

        <span class="n">expectation_z</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">quantum_circuit</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="nb">input</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">tolist</span><span class="p">())</span>
        <span class="n">result</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">expectation_z</span><span class="p">])</span>
        <span class="n">ctx</span><span class="p">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">result</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">result</span>

    <span class="o">@</span><span class="nb">staticmethod</span>
    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">):</span>
        <span class="s">""" Backward pass computation """</span>
        <span class="nb">input</span><span class="p">,</span> <span class="n">expectation_z</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">saved_tensors</span>
        <span class="n">input_list</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">input</span><span class="p">.</span><span class="n">tolist</span><span class="p">())</span>

        <span class="n">shift_right</span> <span class="o">=</span> <span class="n">input_list</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_list</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="n">ctx</span><span class="p">.</span><span class="n">shift</span>
        <span class="n">shift_left</span> <span class="o">=</span> <span class="n">input_list</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_list</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="n">ctx</span><span class="p">.</span><span class="n">shift</span>

        <span class="n">gradients</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_list</span><span class="p">)):</span>
            <span class="n">expectation_right</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">quantum_circuit</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">shift_right</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
            <span class="n">expectation_left</span>  <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">quantum_circuit</span><span class="p">.</span><span class="n">run</span><span class="p">(</span><span class="n">shift_left</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>

            <span class="n">gradient</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">expectation_right</span><span class="p">])</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">expectation_left</span><span class="p">])</span>
            <span class="n">gradients</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">gradient</span><span class="p">)</span>
        <span class="n">gradients</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">gradients</span><span class="p">]).</span><span class="n">T</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">gradients</span><span class="p">]).</span><span class="nb">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">grad_output</span><span class="p">.</span><span class="nb">float</span><span class="p">(),</span> <span class="bp">None</span><span class="p">,</span> <span class="bp">None</span>
</code></pre></div></div>
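
<p>As a sanity check on <code class="language-plaintext highlighter-rouge">backward</code>, we can apply the parameter shift rule to a closed-form expectation. Assume the circuit is a Hadamard followed by $R_Y(\theta)$. Then the probability of measuring 1 is $f(\theta) = (1 + \sin\theta)/2$. For a rotation gate like $R_Y$ with a shift of $\pi/2$, the parameter shift rule reads $f'(\theta) = \frac{1}{2}\left[f(\theta + \pi/2) - f(\theta - \pi/2)\right]$, which here gives exactly $\cos\theta / 2$. The NumPy sketch below compares that rule with what <code class="language-plaintext highlighter-rouge">backward</code> computes. The helper names are hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def f(theta):
    """Assumed closed-form expectation of the H then RY(theta) circuit."""
    return (1 + np.sin(theta)) / 2

def parameter_shift_grad(func, theta, shift=np.pi / 2):
    """Parameter shift rule for a rotation gate exp(-i * theta * P / 2), P a Pauli operator."""
    return (func(theta + shift) - func(theta - shift)) / 2

def backward_as_written(func, theta, shift=np.pi / 2):
    """What HybridFunction.backward computes: the same difference, without the 1/2."""
    return func(theta + shift) - func(theta - shift)

for theta in [0.0, 0.7, np.pi / 3, 2.5]:
    exact = np.cos(theta) / 2
    print(f"theta={theta:.3f}  exact={exact:.6f}  "
          f"shift rule={parameter_shift_grad(f, theta):.6f}  "
          f"as written={backward_as_written(f, theta):.6f}")
</code></pre></div></div>

<p>As written, <code class="language-plaintext highlighter-rouge">backward</code> returns the difference without the factor of $\frac{1}{2}$, so its gradients are twice the exact derivative. Every gradient in the network flows through the hybrid layer, so the whole gradient is scaled by the same constant. Adam’s updates are largely insensitive to such a rescaling.</p>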

<p>With that, we can create an actual PyTorch layer. It is a subclass of <code class="language-plaintext highlighter-rouge">nn.Module</code> that simply applies <code class="language-plaintext highlighter-rouge">HybridFunction</code> in its forward pass.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Hybrid</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">,</span> <span class="n">shots</span><span class="p">,</span> <span class="n">shift</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Hybrid</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">quantum_circuit</span> <span class="o">=</span> <span class="n">QuantumCircuit</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">backend</span><span class="p">,</span> <span class="n">shots</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">shift</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">HybridFunction</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">quantum_circuit</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">shift</span><span class="p">)</span>
</code></pre></div></div>
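
<p>Here is a toy NumPy analogue of how the forward/backward split lets a quantum parameter learn. It needs neither PyTorch nor Qiskit, and the class is hypothetical. Its circuit is the closed-form expectation $(1 + \sin\theta)/2$, which assumes a Hadamard followed by $R_Y(\theta)$. The loss is the negative probability of the target, which is what <code class="language-plaintext highlighter-rouge">nn.NLLLoss</code> computes when handed a probability. A single angle $\theta$ is trained with plain gradient descent using parameter-shift gradients.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

class ToyHybrid:
    """Hypothetical stand-in for HybridFunction with a closed-form 'circuit'."""

    def __init__(self, shift=np.pi / 2):
        self.shift = shift

    def run(self, theta):
        # Assumed expectation of the H-then-RY(theta) circuit: P(measure 1)
        return (1 + np.sin(theta)) / 2

    def forward(self, theta):
        return self.run(theta)

    def backward(self, theta, grad_output):
        # Parameter shift rule (including the 1/2 factor), chained with the upstream gradient
        local_grad = (self.run(theta + self.shift) - self.run(theta - self.shift)) / 2
        return grad_output * local_grad

layer = ToyHybrid()
theta, lr = 0.0, 0.5
losses = []
for _ in range(200):
    p = layer.forward(theta)
    losses.append(-p)                                     # NLLLoss on a probability gives -p
    grad_theta = layer.backward(theta, grad_output=-1.0)  # d(loss)/dp = -1
    theta -= lr * grad_theta                              # plain gradient descent step

print(round(losses[0], 4), round(losses[-1], 4))  # -0.5 at the start, close to -1 at the end
print(round(theta, 4))                            # close to pi/2, about 1.5708
</code></pre></div></div>

<p>In the real <code class="language-plaintext highlighter-rouge">Hybrid</code> layer, PyTorch supplies <code class="language-plaintext highlighter-rouge">grad_output</code> and chains the result back into <code class="language-plaintext highlighter-rouge">fc2</code> and the convolutions. Here we pass the upstream gradient by hand to keep the mechanics visible.</p>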

<h3 id="loading-data">Loading Data</h3>

<h4 id="training-dataset">Training Dataset</h4>

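<p>As mentioned, we’ll use MNIST but keep only two of its classes: 0s and 1s. We load the dataset through torchvision’s <code class="language-plaintext highlighter-rouge">datasets</code> module for both training and testing. Following the textbook example, we keep the first 100 images of each class for training (200 in total) and the first 50 of each class for testing (100 in total). Both loaders use a batch size of 1, since <code class="language-plaintext highlighter-rouge">HybridFunction</code> only runs the circuit on <code class="language-plaintext highlighter-rouge">input[0]</code>.</p>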

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples</span> <span class="o">=</span> <span class="mi">100</span>

<span class="n">X_train</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s">'./data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                         <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">()]))</span>

<span class="c1"># Leaving only labels 0 and 1
</span><span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">],</span>
                <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">])</span>

<span class="n">X_train</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>

<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
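
<p>The index bookkeeping above packs a lot into two lines. <code class="language-plaintext highlighter-rouge">np.where(X_train.targets == 0)[0]</code> returns the positions of every 0 label, and <code class="language-plaintext highlighter-rouge">[:n_samples]</code> keeps the first 100 of them. <code class="language-plaintext highlighter-rouge">np.append</code> then joins those with the positions of the first 100 ones. A toy version on a made-up label array makes the result easy to inspect. The helper name is hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def first_n_per_class(targets, classes, n):
    """Indices of the first n samples of each class, grouped class by class."""
    targets = np.asarray(targets)
    per_class = [np.where(targets == c)[0][:n] for c in classes]
    return np.concatenate(per_class)  # same result as np.append for two classes

toy_targets = np.array([3, 0, 1, 0, 2, 1, 0, 1])  # made-up labels
idx = first_n_per_class(toy_targets, classes=[0, 1], n=2)
print(idx)               # [1 3 2 5]
print(toy_targets[idx])  # [0 0 1 1]
</code></pre></div></div>

<p>Ordering the samples by class is harmless here, because the <code class="language-plaintext highlighter-rouge">DataLoader</code> is created with <code class="language-plaintext highlighter-rouge">shuffle=True</code>.</p>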

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples_show</span> <span class="o">=</span> <span class="mi">6</span>

<span class="n">data_iter</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="n">n_samples_show</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

<span class="k">while</span> <span class="n">n_samples_show</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    <span class="n">images</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">data_iter</span><span class="p">)</span>

    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">().</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">)</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([])</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_yticks</span><span class="p">([])</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Labeled: {}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">targets</span><span class="p">.</span><span class="n">item</span><span class="p">()))</span>

    <span class="n">n_samples_show</span> <span class="o">-=</span> <span class="mi">1</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_13_0.png" style="zoom: 70%;" />
</center>

<h4 id="testing-dataset">Testing Dataset</h4>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples</span> <span class="o">=</span> <span class="mi">50</span>

<span class="n">X_test</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s">'./data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                        <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">()]))</span>

<span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">],</span>
                <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">])</span>

<span class="n">X_test</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>

<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="hybrid-neural-network">Hybrid Neural Network</h3>

<p>With most of the pieces in place, we can build our model. The classical part uses ordinary convolution, dropout and linear layers. Notice that the final linear layer <code class="language-plaintext highlighter-rouge">fc2</code> has only 1 output, since our quantum layer has only 1 parameter, $\theta$. The hybrid layer returns the probability $x$ of measuring 1. The forward pass therefore concatenates $x$ and $1 - x$ into a single tensor of two class probabilities, which we’ll later pass to our loss function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout2d</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">64</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">hybrid</span> <span class="o">=</span> <span class="n">Hybrid</span><span class="p">(</span><span class="n">qiskit</span><span class="p">.</span><span class="n">Aer</span><span class="p">.</span><span class="n">get_backend</span><span class="p">(</span><span class="s">'qasm_simulator'</span><span class="p">),</span> <span class="mi">100</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">hybrid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">x</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
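
<p>It is worth tracing the tensor sizes through <code class="language-plaintext highlighter-rouge">forward</code> to see what <code class="language-plaintext highlighter-rouge">x.view(-1, 256)</code> actually does. The sketch below uses the standard output-size formula for unpadded convolution and pooling layers. It follows a single 28 × 28 MNIST image through the layers exactly as written above. The helper name is hypothetical.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d or max-pool layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                  # MNIST images are 28 x 28
size = conv_out(size, kernel=5)            # conv1: 24
size = conv_out(size, kernel=5)            # conv2: 20
size = conv_out(size, kernel=2, stride=2)  # max_pool2d(x, 2): 10
features = 64 * size * size                # conv2 has 64 output channels
rows = features // 256                     # rows produced by x.view(-1, 256)
print(size, features, rows)                # 10 6400 25
</code></pre></div></div>

<p>By this arithmetic, each image yields 6400 features. So <code class="language-plaintext highlighter-rouge">x.view(-1, 256)</code> reshapes them into 25 rows of 256 rather than one row, and <code class="language-plaintext highlighter-rouge">fc2</code> produces 25 angles. Since <code class="language-plaintext highlighter-rouge">HybridFunction</code> only sends <code class="language-plaintext highlighter-rouge">input[0]</code> through the circuit, only the first of those rows influences the prediction. One way to let every feature reach the quantum layer is to flatten with <code class="language-plaintext highlighter-rouge">x.view(1, -1)</code> and size <code class="language-plaintext highlighter-rouge">fc1</code> as <code class="language-plaintext highlighter-rouge">nn.Linear(6400, 64)</code>. That variant was not used to produce the results in this post.</p>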

<h3 id="training-neural-network">Training Neural Network</h3>

<p>Finally, we train the model just as we would train an ordinary image classifier. The quantum layer’s backward pass is implemented in <code class="language-plaintext highlighter-rouge">HybridFunction</code>. Calling <code class="language-plaintext highlighter-rouge">loss.backward()</code> therefore applies the parameter shift rule to the quantum layer and regular backpropagation to the classical layers.</p>

<p>We’ll train for 20 epochs with Adam (learning rate 0.001) and record the average loss of each epoch.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">NLLLoss</span><span class="p">()</span>

<span class="n">epochs</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">loss_list</span> <span class="o">=</span> <span class="p">[]</span>

<span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="n">total_loss</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="c1"># Forward pass
</span>        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
        <span class="c1"># Calculating loss
</span>        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
        <span class="c1"># Backward pass
</span>        <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="c1"># Optimize the weights
</span>        <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

        <span class="n">total_loss</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>
    <span class="n">loss_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">total_loss</span><span class="p">))</span>
    <span class="k">print</span><span class="p">(</span><span class="s">'Training [{:.0f}%]</span><span class="se">\t</span><span class="s">Loss: {:.4f}'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
        <span class="mf">100.</span> <span class="o">*</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">loss_list</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Training [5%]	Loss: -0.6274
Training [10%]	Loss: -0.7605
Training [15%]	Loss: -0.7898
Training [20%]	Loss: -0.8343
Training [25%]	Loss: -0.8573
Training [30%]	Loss: -0.8514
Training [35%]	Loss: -0.8776
Training [40%]	Loss: -0.8414
Training [45%]	Loss: -0.8811
Training [50%]	Loss: -0.8226
Training [55%]	Loss: -0.8174
Training [60%]	Loss: -0.8588
Training [65%]	Loss: -0.8629
Training [70%]	Loss: -0.8767
Training [75%]	Loss: -0.8635
Training [80%]	Loss: -0.8688
Training [85%]	Loss: -0.8795
Training [90%]	Loss: -0.9021
Training [95%]	Loss: -0.8732
Training [100%]	Loss: -0.8694
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">loss_list</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">'Hybrid NN Training Convergence'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Training Iterations'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Neg Log Likelihood Loss'</span><span class="p">)</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_20_1.png" style="zoom: 70%;" />
</center>

<h3 id="testing-neural-network">Testing Neural Network</h3>

<p>As the plot above shows, the loss has decreased over training, with some fluctuation, so the model seems to have learned well. To see how it fares, let’s evaluate it on the test data we set aside earlier.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>

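    <span class="c1"># Start with an empty loss list so the training losses are not included
</span>    <span class="n">total_loss</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>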
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">):</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>

        <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="p">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="p">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">)).</span><span class="nb">sum</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>

        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
        <span class="n">total_loss</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>

    <span class="k">print</span><span class="p">(</span><span class="s">'Performance on test data:</span><span class="se">\n\t</span><span class="s">Loss: {:.4f}</span><span class="se">\n\t</span><span class="s">Accuracy: {:.1f}%'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
        <span class="nb">sum</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">total_loss</span><span class="p">),</span>
        <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span>
        <span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Performance on test data:
	Loss: -0.8713
	Accuracy: 100.0%
</code></pre></div></div>

<p>The model achieves 100% accuracy on the small test set, which is reasonable given how different 0s and 1s look.</p>
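<p>The bookkeeping in a test loop is easy to get subtly wrong: the loss list should start empty for each evaluation, and accuracy should be divided by the number of samples rather than the number of batches (the two only coincide here because <code>batch_size=1</code>). Below is a framework-free sketch of that logic. The <code>evaluate</code> helper and its input format (one list of class probabilities per sample) are hypothetical, and the per-sample loss of <code>-probs[target]</code> assumes, as the negative loss values suggest, that the network outputs probabilities which are then passed to <code>nn.NLLLoss</code>.</p>

```python
def evaluate(outputs, targets):
    """Hypothetical helper: return (mean loss, accuracy in %).

    Assumes each output is a list of class probabilities, so the
    NLLLoss-style loss for one sample is -probs[target].
    """
    total_loss = []  # fresh accumulator, not shared with training
    correct = 0
    for probs, target in zip(outputs, targets):
        # Index of the largest probability, like output.argmax(dim=1)
        pred = max(range(len(probs)), key=lambda i: probs[i])
        correct += int(pred == target)
        total_loss.append(-probs[target])
    n = len(targets)  # number of samples, not number of batches
    return sum(total_loss) / n, 100.0 * correct / n


# Toy example: three samples, two classes, two predicted correctly.
outputs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
targets = [0, 1, 1]
loss, acc = evaluate(outputs, targets)
print(f"Loss: {loss:.4f}\tAccuracy: {acc:.1f}%")
```

<p>For these toy values the sketch reports a loss of -0.7000 and an accuracy of 66.7%.</p>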

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples_show</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="n">n_samples_show</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="n">n_samples_show</span><span class="p">:</span>
            <span class="k">break</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>

        <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">imshow</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">().</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">)</span>

        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([])</span>
        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_yticks</span><span class="p">([])</span>
        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Predicted {}'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">pred</span><span class="p">.</span><span class="n">item</span><span class="p">()))</span>

        <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_24_0.png" style="zoom: 70%;" />
</center>

<h2 id="code-classifying-3s-and-7s">Code: Classifying 3s and 7s</h2>

<p>Given how well the model did, I tried changing the dataset. Instead of 0s and 1s, which look fairly different from each other, I used 3s and 7s to see how the model copes. Apart from data loading, the process is pretty much identical.</p>

<h3 id="loading-data-1">Loading Data</h3>

<h4 id="training-dataset-1">Training Dataset</h4>

<p>Here we specify that we want only 3s and 7s, and encode their labels as 0 and 1, respectively.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples</span> <span class="o">=</span> <span class="mi">100</span>

<span class="n">X_train</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s">'./data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                         <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">()]))</span>

<span class="c1"># Leaving only labels 3 and 7
</span><span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">3</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">],</span>
                <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">7</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">])</span>

<span class="n">X_train</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="c1"># Encode into 0 and 1
</span><span class="n">X_train</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="mi">3</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span> <span class="n">X_train</span><span class="p">.</span><span class="n">targets</span><span class="p">)))</span>

<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
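<p>The index selection and the <code>map</code>/<code>lambda</code> encoding above can also be written with vectorized comparisons. The sketch below uses NumPy on a small made-up label array standing in for <code>X_train.targets</code>; <code>select_and_encode</code> is a hypothetical helper name. On a PyTorch tensor that already contains only 3s and 7s, the encoding step could similarly be written as <code>(X_train.targets == 7).long()</code>.</p>

```python
import numpy as np


def select_and_encode(targets, neg_label=3, pos_label=7, n_samples=100):
    """Hypothetical helper: pick the first n_samples of each digit.

    Returns (idx, encoded), where encoded is 0 for neg_label and 1 for
    pos_label, matching the lambda-based encoding in the post.
    """
    targets = np.asarray(targets)
    idx = np.append(np.where(targets == neg_label)[0][:n_samples],
                    np.where(targets == pos_label)[0][:n_samples])
    # A boolean comparison cast to int replaces the map/lambda loop.
    encoded = (targets[idx] == pos_label).astype(np.int64)
    return idx, encoded


# Made-up labels standing in for X_train.targets.
labels = [3, 1, 7, 3, 7, 0, 3, 7]
idx, encoded = select_and_encode(labels, n_samples=2)
print(idx)      # [0 3 2 4]
print(encoded)  # [0 0 1 1]
```

<p>Taking the first <code>n_samples</code> of each digit keeps the two classes balanced, and the encoded labels line up with <code>idx</code> position by position.</p>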

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples_show</span> <span class="o">=</span> <span class="mi">6</span>

<span class="n">data_iter</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="n">n_samples_show</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

<span class="k">while</span> <span class="n">n_samples_show</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    <span class="n">images</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">data_iter</span><span class="p">)</span>

    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">().</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">)</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([])</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_yticks</span><span class="p">([])</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">n_samples_show</span> <span class="o">-</span> <span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Labeled: {}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">targets</span><span class="p">.</span><span class="n">item</span><span class="p">()))</span>

    <span class="n">n_samples_show</span> <span class="o">-=</span> <span class="mi">1</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_28_0.png" style="zoom: 70%;" />
</center>

<h4 id="testing-dataset-1">Testing Dataset</h4>

<p>The test set goes through exactly the same process of keeping only 3s and 7s and encoding their labels.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples</span> <span class="o">=</span> <span class="mi">50</span>

<span class="n">X_test</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="s">'./data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                        <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">()]))</span>

<span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">3</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">],</span>
                <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">==</span> <span class="mi">7</span><span class="p">)[</span><span class="mi">0</span><span class="p">][:</span><span class="n">n_samples</span><span class="p">])</span>

<span class="n">X_test</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">X_test</span><span class="p">.</span><span class="n">targets</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">X_test</span><span class="p">.</span><span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="mi">3</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span> <span class="n">X_test</span><span class="p">.</span><span class="n">targets</span><span class="p">)))</span>

<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="training-neural-network-1">Training Neural Network</h3>

<p>I used the exact same training loop as before.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
<span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">NLLLoss</span><span class="p">()</span>

<span class="n">epochs</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">loss_list</span> <span class="o">=</span> <span class="p">[]</span>

<span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="n">total_loss</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
        <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

        <span class="n">total_loss</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>
    <span class="n">loss_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">total_loss</span><span class="p">))</span>
    <span class="k">print</span><span class="p">(</span><span class="s">'Training [{:.0f}%]</span><span class="se">\t</span><span class="s">Loss: {:.4f}'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
        <span class="mf">100.</span> <span class="o">*</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">loss_list</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Training [5%]	Loss: -0.4957
Training [10%]	Loss: -0.5000
Training [15%]	Loss: -0.4913
Training [20%]	Loss: -0.5009
Training [25%]	Loss: -0.5024
Training [30%]	Loss: -0.4997
Training [35%]	Loss: -0.6483
Training [40%]	Loss: -0.6767
Training [45%]	Loss: -0.6585
Training [50%]	Loss: -0.6675
Training [55%]	Loss: -0.7013
Training [60%]	Loss: -0.7226
Training [65%]	Loss: -0.7191
Training [70%]	Loss: -0.7031
Training [75%]	Loss: -0.7167
Training [80%]	Loss: -0.7193
Training [85%]	Loss: -0.7220
Training [90%]	Loss: -0.7300
Training [95%]	Loss: -0.7376
Training [100%]	Loss: -0.7249
</code></pre></div></div>

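<p>Once training got going, the loss converged somewhat more smoothly than before, although it sat at around -0.50 for the first six epochs before jumping to about -0.65 in the seventh. It also leveled off at a higher (worse) value, around -0.72, compared with roughly -0.87 for the 0s-and-1s model.</p>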
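<p>All of these loss values are negative, which a true negative log-likelihood can never be. <code>nn.NLLLoss</code> does not take a logarithm itself: it returns <code>-input[target]</code> (averaged over the batch by default) and expects log-probabilities as input. If the network instead outputs plain probabilities, which the values here suggest but which is an assumption about the model defined earlier, the "loss" is just minus the probability assigned to the correct class and lies in [-1, 0]. Under that reading, a plateau near -0.50 means the model was assigning roughly equal probability to both classes. The pure-Python sketch below mimics that computation with a hypothetical <code>nll_loss</code> function.</p>

```python
import math


def nll_loss(inputs, targets):
    """Hypothetical stand-in for nn.NLLLoss (mean reduction, no class weights).

    Returns the mean of -inputs[i][targets[i]]; no logarithm is taken here.
    """
    return sum(-row[t] for row, t in zip(inputs, targets)) / len(targets)


# An undecided two-class model: probability 0.5 for each class.
probs = [[0.5, 0.5], [0.5, 0.5]]
targets = [0, 1]

# Fed raw probabilities, the "loss" is just -p(correct class).
print(nll_loss(probs, targets))  # -0.5

# Fed log-probabilities, it becomes the usual non-negative NLL (ln 2 here).
log_probs = [[math.log(p) for p in row] for row in probs]
print(nll_loss(log_probs, targets))
```

<p>Applying <code>torch.log</code> to the probabilities before the loss would give the conventional non-negative value, where ln 2 ≈ 0.693 corresponds to a coin-flip guess between two classes.</p>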

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">loss_list</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">'Hybrid NN Training Convergence'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Training Iterations'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Neg Log Likelihood Loss'</span><span class="p">)</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_34_1.png" style="zoom: 70%;" />
</center>

<h3 id="testing-neural-network-1">Testing Neural Network</h3>

<p>Testing follows the same process as before, except that the predicted labels 0 and 1 are decoded back into 3 and 7 for convenience.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>

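    <span class="c1"># Start with an empty loss list so the training losses are not included
</span>    <span class="n">total_loss</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>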
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">):</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>

        <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="p">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="p">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">)).</span><span class="nb">sum</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>

        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
        <span class="n">total_loss</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>

    <span class="k">print</span><span class="p">(</span><span class="s">'Performance on test data:</span><span class="se">\n\t</span><span class="s">Loss: {:.4f}</span><span class="se">\n\t</span><span class="s">Accuracy: {:.1f}%'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
        <span class="nb">sum</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">total_loss</span><span class="p">),</span>
        <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span>
        <span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Performance on test data:
	Loss: -0.7454
	Accuracy: 91.0%
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_samples_show</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="n">n_samples_show</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">count</span> <span class="o">==</span> <span class="n">n_samples_show</span><span class="p">:</span>
            <span class="k">break</span>
        <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>

        <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">imshow</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">().</span><span class="n">squeeze</span><span class="p">(),</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'gray'</span><span class="p">)</span>

        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([])</span>
        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_yticks</span><span class="p">([])</span>
        <span class="n">axes</span><span class="p">[</span><span class="n">count</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Predicted {}'</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="mi">3</span> <span class="k">if</span> <span class="n">pred</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">7</span><span class="p">))</span>

        <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/mnist-hybrid-qnn/output_37_0.png" style="zoom: 70%;" />
</center>

<p>Notice that the model achieved a lower test accuracy this time. There are several possible reasons for this, but the details aren’t important here.</p>

<h2 id="closing-remarks">Closing Remarks</h2>

<h3 id="benefits-of-hybrid-neural-networks">Benefits of Hybrid Neural Networks</h3>

<p>All the circuits we’ve used are classically simulatable, which means we’re not leveraging the potential of quantum computation, such as <em>entanglement</em>. The authors of the textbook also mention that the model would have trained just as well, or even better, without the quantum layer.</p>

<p>Without exploiting quantum phenomena, the results will probably be similar to those of an ordinary, classical neural network. For now, however, we can still experiment with these kinds of networks to see whether they offer any real benefits. Achieving a meaningful “quantum advantage” would likely require a more sophisticated quantum layer.</p>

<h3 id="experienced-issues">Experienced Issues</h3>

<p>Although the results in this post look reasonable, I got questionable results in one of my earliest attempts. Despite using a simulator, the network in that trial <strong>didn’t learn</strong>, or the qubit’s measurement results were just very unlucky. The loss stayed at around $-0.5$ after 20 epochs, and the model achieved only 50% accuracy during testing, no better than a random guess. Here is the loss graph for that network:</p>

<center>
<img src="/images/fluctuativeloss.png" style="zoom: 70%;" /><br />
<figcaption><i>Fluctuating Loss of a Hybrid NN</i></figcaption>
</center>
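<p>Why does a loss of $-0.5$ mean “coin flip”? Assume, as in the Qiskit textbook’s hybrid network, that the model outputs class probabilities $[p, 1-p]$ directly (not log-probabilities) and is trained with PyTorch’s <code class="language-plaintext highlighter-rouge">NLLLoss</code>. Under that assumption, the loss is the negative mean probability assigned to the correct class. Its minimum is $-1$, and $-0.5$ means the correct class gets probability 0.5 on average. The pure-Python sketch below illustrates this. The function name is my own, not part of PyTorch or Qiskit.</p>

```python
def nll_on_probabilities(probs, targets):
    """Mimic nn.NLLLoss (mean reduction) when fed raw probabilities.

    probs:   list of [p_class0, p_class1] rows (probabilities, not logs)
    targets: list of correct class indices, one per row
    """
    picked = [row[t] for row, t in zip(probs, targets)]
    return -sum(picked) / len(picked)


# A network that always outputs [0.5, 0.5] is just guessing:
print(nll_on_probabilities([[0.5, 0.5], [0.5, 0.5]], [0, 1]))  # -0.5

# A confident, correct network approaches the minimum of -1:
print(nll_on_probabilities([[0.95, 0.05], [0.1, 0.9]], [0, 1]))
```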

<h3 id="overall-results">Overall Results</h3>

<p>It’s very amusing to see how we can fuse quantum and classical layers together to create such neural networks, even if there is no particular advantage in doing so. Regardless, we can build on the simple concepts the textbook has shown and see whether hybrid NNs can be put to good use in the future!</p>

<h3 id="credits">Credits</h3>

<p>Asfaw, A., Bello, L., Ben-Haim, Y., Bravyi, S., Capelluto, L., Vazquez, A. C., . . . Wootton, J. (2020). <em>Learn Quantum Computation Using Qiskit</em>. Retrieved from http://community.qiskit.org/textbook</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Quantum Machine Learning" /><category term="Quantum Computation" /><summary type="html"><![CDATA[Qiskit is IBM’s open-source framework to do quantum processes which provides users access to both simulators and real Quantum Computers. Today, the Quantum Computer available is still in the Noisy Intermediate-Scale Quantum (NISQ) era and is very much sensitive to any forms of interference. Unlike real Quantum Computers, simulators provided by Qiskit aren’t noisy and is great for prototyping.]]></summary></entry><entry><title type="html">Color Restoration with Generative Adversarial Network</title><link href="https://wilsonwongso.dev/posts/2020/07/color-restoration-gan/" rel="alternate" type="text/html" title="Color Restoration with Generative Adversarial Network" /><published>2020-07-10T00:00:00+10:00</published><updated>2020-07-10T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/07/color-restoration-gan</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/07/color-restoration-gan/"><![CDATA[<p><a href="https://www.fast.ai/">Fast.ai</a> has a two-part Deep Learning Course, the first being <a href="https://course.fast.ai/">Practical Deep Learning for Coders</a>, and the second being <a href="https://course.fast.ai/part2">Deep Learning from the Foundations</a>, both having different approaches and intended for different audiences. 
In the <a href="https://course.fast.ai/videos/?lesson=7">7th lecture</a> of Part 1, Jeremy Howard covered several modern architectures, such as <a href="https://arxiv.org/abs/1512.03385">Residual Network (ResNet)</a>, <a href="https://arxiv.org/abs/1505.04597">U-Net</a>, and <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Network (GAN)</a>.</p>

<h3 id="generative-adversarial-networks">Generative Adversarial Networks</h3>

<p>GANs were introduced in 2014 by Ian Goodfellow and his colleagues; Goodfellow is one of the prominent figures in modern Deep Learning. GANs have been used for image-to-image tasks such as <a href="https://www.tensorflow.org/tutorials/generative/pix2pix">Pix2Pix</a> and <a href="https://www.tensorflow.org/tutorials/generative/cyclegan">CycleGAN</a>. They are often showcased alongside other generative techniques like <a href="https://www.tensorflow.org/tutorials/generative/style_transfer">Style Transfer</a>. Today, I’ll be experimenting with Image Restoration.</p>

<center>
<img src="/images/stylized-image.png" style="zoom: 70%;" /><br />
<figcaption>Style Transfer Result | Tensorflow Tutorials</figcaption>
</center>

<h3 id="image-restoration">Image Restoration</h3>

<p>There are different elements of an image that one can attempt to restore. The example Jeremy showed was restoring low-resolution images into higher-resolution ones, which produces something like the following:</p>

<center>
<img src="/images/restored-image.png" style="zoom: 70%;" /><br />
<figcaption>Image Restoration Result | fast.ai</figcaption>
</center>

<p>Jeremy also mentioned that GANs can restore not only an image’s resolution but also other elements. Examples include removing JPEG-like artifacts, removing different kinds of noise, or even restoring colors. I was immediately hooked, finished the lecture, and set out to try what I had learned, which is how this project came about.</p>

<h3 id="color-restoration">Color Restoration</h3>

<p>Instead of turning low-resolution images into high-resolution images, I wanted to build a network that can recolor black-and-white images. The approach still follows the same GAN setup, with a few tweaks that we’ll discuss further down.</p>

<h3 id="code-source">Code Source</h3>

<p>Since this was the first time I’d worked with generative networks like GANs, I based my code heavily on a fast.ai notebook, <a href="https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb">lesson7-superres-gan.ipynb</a>.</p>

<p>The code below isn’t complete; only the most important blocks are shown.</p>

<h2 id="the-gan-approach">The GAN Approach</h2>

<p>A GAN is like a game between two players: the <strong>artist</strong> (formally, the generator) and the <strong>critic</strong> (formally, the discriminator). Each has its own role. The artist has to produce an image, while the critic has to decide whether a given image is <em>real</em> or <em>fake/generated</em> by the artist.</p>

<p>Both of them have to get better at what they do. The critic has to get better at telling real images from fake ones, while the artist has to improve its images to <em>fool</em> the critic. For image restoration, the artist has to produce a <strong>higher resolution</strong> image from a low-resolution input. Meanwhile, the critic learns to <strong>distinguish</strong> the artist’s upscaled images from real high-resolution images.</p>

<p>Color restoration works the same way, except the critic now tells artist-recolored images apart from originally colored images. Meanwhile, the artist learns to recolor images well enough to outsmart the critic.</p>

<h3 id="data-modification">Data Modification</h3>

<p>A network that learns both to recolor images and to classify real versus fake images needs two sets of data: the colored images and their black-and-white versions. We used the colored images of the <a href="http://www.robots.ox.ac.uk/~vgg/data/pets/">Oxford-IIIT Pet Dataset</a> and wrote a function to grayscale them. Jeremy calls such a function a <em>crappifier</em>; in our case, it only grayscales the images. Once we have both the colored and grayscaled images, we can use them to train the network.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>

<span class="k">class</span> <span class="nc">crappifier</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path_lr</span><span class="p">,</span> <span class="n">path_hr</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">path_lr</span> <span class="o">=</span> <span class="n">path_lr</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">path_hr</span> <span class="o">=</span> <span class="n">path_hr</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
        <span class="c1"># `i` is the item index that fastai's `parallel` helper passes in; it is unused here</span>
        <span class="c1"># mirror the folder structure of path_hr inside path_lr</span>
        <span class="n">dest</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">path_lr</span><span class="o">/</span><span class="n">fn</span><span class="p">.</span><span class="n">relative_to</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">path_hr</span><span class="p">)</span>
        <span class="n">dest</span><span class="p">.</span><span class="n">parent</span><span class="p">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="p">.</span><span class="nb">open</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
        <span class="c1"># 'L' mode = single-channel 8-bit grayscale</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="p">.</span><span class="n">convert</span><span class="p">(</span><span class="s">'L'</span><span class="p">)</span>
        <span class="n">img</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">dest</span><span class="p">,</span> <span class="n">quality</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
</code></pre></div></div>

<center>
<img src="/images/grayscaled-image.png" style="zoom: 70%;" /><br />
<figcaption>Grayscaled Images</figcaption>
</center>
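<p>For reference, Pillow’s documentation says converting to mode <code class="language-plaintext highlighter-rouge">'L'</code> uses the ITU-R 601-2 luma transform:</p>

<p>L = R × 299/1000 + G × 587/1000 + B × 114/1000</p>

<p>So green contributes most to the grayscale value and blue the least. The pure-Python sketch below applies that per-pixel formula. Pillow’s internal integer rounding may differ slightly, so treat the rounded values as approximate; the helper names are hypothetical.</p>

```python
def luma_601(r, g, b):
    """Approximate grayscale value of one RGB pixel (ITU-R 601-2 luma)."""
    return round(r * 299 / 1000 + g * 587 / 1000 + b * 114 / 1000)


def grayscale(pixels):
    """Convert a list of (R, G, B) tuples into a list of 0-255 gray levels."""
    return [luma_601(r, g, b) for r, g, b in pixels]


# White, black, pure red, pure green
print(grayscale([(255, 255, 255), (0, 0, 0), (255, 0, 0), (0, 255, 0)]))  # [255, 0, 76, 150]
```

<p>This is also why recoloring is hard: many different colors collapse onto the same gray level, so the generator has to infer plausible colors from context.</p>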

<h3 id="pre-train-generatorartist">Pre-train Generator/Artist</h3>

<p>Now we will train the generator on its own before using it in a GAN. The architecture is a U-Net with ResNet34 as its encoder. At this stage, all it’s trained to do is recolor the grayscale images so they look as close as possible to their colored counterparts. Notice also that we’re using Mean Squared Error (fastai’s <code class="language-plaintext highlighter-rouge">MSELossFlat</code>) as our loss function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">arch</span> <span class="o">=</span> <span class="n">models</span><span class="p">.</span><span class="n">resnet34</span>
<span class="n">loss_gen</span> <span class="o">=</span> <span class="n">MSELossFlat</span><span class="p">()</span>

<span class="n">learn_gen</span> <span class="o">=</span> <span class="n">unet_learner</span><span class="p">(</span><span class="n">data_gen</span><span class="p">,</span> <span class="n">arch</span><span class="p">,</span> <span class="n">wd</span><span class="o">=</span><span class="n">wd</span><span class="p">,</span> <span class="n">blur</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">norm_type</span><span class="o">=</span><span class="n">NormType</span><span class="p">.</span><span class="n">Weight</span><span class="p">,</span>
                         <span class="n">self_attention</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">y_range</span><span class="o">=</span><span class="n">y_range</span><span class="p">,</span> <span class="n">loss_func</span><span class="o">=</span><span class="n">loss_gen</span><span class="p">)</span>
</code></pre></div></div>
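<p>Conceptually, <code class="language-plaintext highlighter-rouge">MSELossFlat</code> flattens the prediction and the target before taking the ordinary mean squared error. The generator is therefore penalized for every pixel value, in every channel, that differs from the colored original. The pure-Python sketch below illustrates the idea using nested lists as stand-in “images”. The helper names are hypothetical, not fastai’s.</p>

```python
def flatten(nested):
    """Recursively flatten nested lists/tuples (e.g. channels x rows x cols)."""
    if isinstance(nested, (list, tuple)):
        flat = []
        for item in nested:
            flat.extend(flatten(item))
        return flat
    return [nested]


def mse_flat(pred, target):
    """Mean squared error over every flattened value of `pred` and `target`."""
    p, t = flatten(pred), flatten(target)
    if len(p) != len(t):
        raise ValueError("pred and target must contain the same number of values")
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


# Two tiny 1-channel 2x2 "images" that differ in a single pixel, by 0.5.
pred = [[[0.0, 1.0], [0.5, 1.0]]]
target = [[[0.0, 1.0], [1.0, 1.0]]]
print(mse_flat(pred, target))  # 0.0625
```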

<p>Once we have the generator, we train it for a couple of epochs while the pretrained encoder is frozen. We then unfreeze the whole model and train it for a few more epochs.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn_gen</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">pct_start</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>0.109306</td>
      <td>0.111038</td>
      <td>02:37</td>
    </tr>
    <tr>
      <td>1</td>
      <td>0.096312</td>
      <td>0.102479</td>
      <td>02:40</td>
    </tr>
  </tbody>
</table>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn_gen</span><span class="p">.</span><span class="n">unfreeze</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn_gen</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="nb">slice</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">,</span><span class="mf">1e-3</span><span class="p">))</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>0.089206</td>
      <td>0.100583</td>
      <td>02:41</td>
    </tr>
    <tr>
      <td>1</td>
      <td>0.087562</td>
      <td>0.094716</td>
      <td>02:44</td>
    </tr>
    <tr>
      <td>2</td>
      <td>0.086839</td>
      <td>0.094106</td>
      <td>02:45</td>
    </tr>
  </tbody>
</table>
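<p>Passing <code class="language-plaintext highlighter-rouge">slice(1e-6, 1e-3)</code> asks fastai (v1) for discriminative learning rates. The earliest, pretrained layer groups get the smallest rate, the last group gets the largest, and the values in between are spaced geometrically. The pure-Python sketch below re-implements that spacing idea; it is not fastai’s code. The group count of 3 is only an example, not the actual number of layer groups in this U-Net.</p>

```python
def geometric_lrs(start, stop, n_groups):
    """Spread learning rates geometrically from `start` to `stop` across layer groups."""
    if n_groups < 2:
        raise ValueError("need at least two layer groups to spread a range")
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step ** i for i in range(n_groups)]


# Example with 3 hypothetical layer groups: early layers barely move, the last group learns fastest.
for lr in geometric_lrs(1e-6, 1e-3, 3):
    print(f"{lr:.2e}")  # prints 1.00e-06, 3.16e-05, 1.00e-03 (one per line)
```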

<p>After a total of 5 epochs, the generated images look like the following:</p>

<center>
<img src="/images/generated-image.png" style="zoom: 70%;" /><br />
<figcaption>Generated Images</figcaption>
</center>

<p>As you can see, the generator did poorly in some areas of the images and well in others. Regardless, we’ll save these generated images to serve as the fake-image dataset for the critic to learn from.</p>

<h3 id="train-discriminatorcritic">Train Discriminator/Critic</h3>

<p>With both sets of images ready, we’ll feed them to a critic and let it learn to distinguish real images from artist-generated ones. Below is a sample batch of data, where the real images are labelled <code class="language-plaintext highlighter-rouge">images</code> and the generated ones <code class="language-plaintext highlighter-rouge">image_gen</code>.</p>

<center>
<img src="/images/critic-data.png" style="zoom: 70%;" /><br />
<figcaption>Real and Generated Images</figcaption>
</center>

<p>To create the critic, we’ll use fast.ai’s built-in <code class="language-plaintext highlighter-rouge">gan_critic</code>, which is a fairly simple Convolutional Neural Network with residual blocks. Unlike the generator, the critic uses Binary Cross Entropy as its loss, since there are only two possible classes: real or generated. Specifically, we use <code class="language-plaintext highlighter-rouge">nn.BCEWithLogitsLoss</code>, which applies the sigmoid internally, and wrap it with <code class="language-plaintext highlighter-rouge">AdaptiveLoss</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">loss_critic</span> <span class="o">=</span> <span class="n">AdaptiveLoss</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">())</span>

<span class="n">learn_critic</span> <span class="o">=</span> <span class="n">Learner</span><span class="p">(</span><span class="n">data_crit</span><span class="p">,</span> <span class="n">gan_critic</span><span class="p">(),</span> <span class="n">metrics</span><span class="o">=</span><span class="n">accuracy_thresh_expand</span><span class="p">,</span> <span class="n">loss_func</span><span class="o">=</span><span class="n">loss_critic</span><span class="p">,</span> <span class="n">wd</span><span class="o">=</span><span class="n">wd</span><span class="p">)</span>
</code></pre></div></div>
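<p>Why wrap the loss? As far as I understand fastai v1, <code class="language-plaintext highlighter-rouge">gan_critic</code> outputs a flattened map of logits, one per image patch, rather than a single logit per image. The label, however, is a single 0/1 per image. <code class="language-plaintext highlighter-rouge">AdaptiveLoss</code> expands each image’s label to the size of its output before applying <code class="language-plaintext highlighter-rouge">BCEWithLogitsLoss</code>, which combines sigmoid and binary cross-entropy in one numerically stable step. <code class="language-plaintext highlighter-rouge">accuracy_thresh_expand</code> expands the label the same way before thresholding the sigmoid at 0.5. The pure-Python sketch below shows both ideas on toy numbers. The function names are my own, not fastai’s.</p>

```python
import math


def bce_with_logits(logit, label):
    """Numerically stable sigmoid + binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def expanded_bce(patch_logits, labels):
    """Broadcast one label per image onto all of that image's patch logits, then average."""
    losses = [bce_with_logits(x, y) for row, y in zip(patch_logits, labels) for x in row]
    return sum(losses) / len(losses)


def expanded_accuracy(patch_logits, labels, thresh=0.5):
    """Fraction of patches whose sigmoid(logit) > thresh agrees with the image's label."""
    hits = [
        (1 / (1 + math.exp(-x)) > thresh) == bool(y)
        for row, y in zip(patch_logits, labels)
        for x in row
    ]
    return sum(hits) / len(hits)


# Two images with three patch logits each: the first is real (1), the second generated (0).
logits = [[2.0, 3.0, -1.0], [-2.5, -0.5, 1.5]]
labels = [1, 0]
print(round(expanded_bce(logits, labels), 4))
print(expanded_accuracy(logits, labels))  # 0.6666666666666666
```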

<p>Once the Learner has been created, we can proceed with training the critic for several epochs.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn_critic</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mf">1e-3</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>accuracy_thresh_expand</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>0.170356</td>
      <td>0.105095</td>
      <td>0.958804</td>
      <td>03:34</td>
    </tr>
    <tr>
      <td>1</td>
      <td>0.041809</td>
      <td>0.022646</td>
      <td>0.992365</td>
      <td>03:27</td>
    </tr>
    <tr>
      <td>2</td>
      <td>0.026520</td>
      <td>0.013480</td>
      <td>0.996638</td>
      <td>03:26</td>
    </tr>
    <tr>
      <td>3</td>
      <td>0.011859</td>
      <td>0.005585</td>
      <td>0.999117</td>
      <td>03:25</td>
    </tr>
    <tr>
      <td>4</td>
      <td>0.012674</td>
      <td>0.005655</td>
      <td>0.999288</td>
      <td>03:25</td>
    </tr>
    <tr>
      <td>5</td>
      <td>0.013518</td>
      <td>0.005413</td>
      <td>0.999288</td>
      <td>03:24</td>
    </tr>
  </tbody>
</table>

<h3 id="gan">GAN</h3>

<p>With both the generator and the critic pretrained, we can finally put them together and start the game of outsmarting each other. We will use <code class="language-plaintext highlighter-rouge">AdaptiveGANSwitcher</code>, which switches training between the generator and the critic based on their losses. Here, the critic keeps training until its loss drops below <code class="language-plaintext highlighter-rouge">critic_thresh=0.65</code>, after which control goes back to the generator.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">switcher</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">AdaptiveGANSwitcher</span><span class="p">,</span> <span class="n">critic_thresh</span><span class="o">=</span><span class="mf">0.65</span><span class="p">)</span>
</code></pre></div></div>
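<p>My reading of the fastai v1 source concerns the case where only <code class="language-plaintext highlighter-rouge">critic_thresh</code> is given. In that case, the generator hands control back to the critic after every batch. The critic then keeps training until its loss on a batch drops below the threshold. The toy simulation below re-implements only that rule in plain Python; it is not fastai’s code. The loss values are made up, and the starting side is an assumption exposed as a parameter.</p>

```python
def simulate_switching(batch_losses, critic_thresh=0.65, gen_thresh=None, start="critic"):
    """Return which model trained on each batch under an adaptive switching rule.

    batch_losses: the loss observed after each batch (toy numbers here).
    The active model hands over control when its threshold is None
    or when its batch loss falls below that threshold.
    """
    mode, schedule = start, []
    for loss in batch_losses:
        schedule.append(mode)
        thresh = gen_thresh if mode == "gen" else critic_thresh
        if thresh is None or loss < thresh:
            mode = "gen" if mode == "critic" else "critic"
    return schedule


losses = [0.80, 0.70, 0.60, 3.10, 0.75, 0.62, 2.90]
print(simulate_switching(losses))
```

<p>The generator gets single batches, while the critic may train for several batches in a row until it is “good enough” again.</p>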

<p>Wrapping both the generator and the critic inside a GAN learner:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span> <span class="o">=</span> <span class="n">GANLearner</span><span class="p">.</span><span class="n">from_learners</span><span class="p">(</span><span class="n">learn_gen</span><span class="p">,</span> <span class="n">learn_critic</span><span class="p">,</span> <span class="n">weights_gen</span><span class="o">=</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span><span class="mf">50.</span><span class="p">),</span> <span class="n">show_img</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">switcher</span><span class="o">=</span><span class="n">switcher</span><span class="p">,</span>
                                 <span class="n">opt_func</span><span class="o">=</span><span class="n">partial</span><span class="p">(</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.</span><span class="p">,</span><span class="mf">0.99</span><span class="p">)),</span> <span class="n">wd</span><span class="o">=</span><span class="n">wd</span><span class="p">)</span>
</code></pre></div></div>
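<p>The <code class="language-plaintext highlighter-rouge">weights_gen=(1., 50.)</code> argument balances the two parts of the generator’s loss. As I understand fastai v1’s <code class="language-plaintext highlighter-rouge">gan_loss_from_func</code>, the generator loss is:</p>

<p>1 × (critic’s BCE when the fake is labelled “real”) + 50 × (the pixel-wise MSE from pre-training)</p>

<p>The critic’s loss is the average of its BCE on real images (label 1) and on generated images (label 0). The pure-Python sketch below uses scalar logits and flat pixel lists. All names are hypothetical.</p>

```python
import math


def bce_with_logits(logits, label):
    """Mean sigmoid + binary cross-entropy of a list of logits against one label."""
    terms = [max(x, 0.0) - x * label + math.log1p(math.exp(-abs(x))) for x in logits]
    return sum(terms) / len(terms)


def mse(pred, target):
    """Plain mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def generator_loss(fake_logits, output, target, weights_gen=(1.0, 50.0)):
    """Adversarial term (make the critic say 'real') plus a heavily weighted pixel loss."""
    adversarial = bce_with_logits(fake_logits, 1)
    pixel = mse(output, target)
    return weights_gen[0] * adversarial + weights_gen[1] * pixel


def critic_loss(real_logits, fake_logits):
    """Average of the 'real looks real' and 'fake looks fake' losses."""
    return (bce_with_logits(real_logits, 1) + bce_with_logits(fake_logits, 0)) / 2


# A completely unsure critic (logit 0) and a generator that copies the target exactly:
print(generator_loss([0.0], [0.2, 0.4], [0.2, 0.4]))  # log(2), about 0.6931
print(critic_loss([0.0], [0.0]))                      # log(2), about 0.6931
```

<p>My interpretation is that the 50× weight keeps the generator close to the colors it learned during pre-training, with the critic’s feedback acting more like a nudge.</p>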

<p>We’ll also use a callback called <code class="language-plaintext highlighter-rouge">GANDiscriminativeLR</code>, which multiplies the learning rate by <code class="language-plaintext highlighter-rouge">mult_lr</code> (here 5) whenever the critic is being trained.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">callback_fns</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">partial</span><span class="p">(</span><span class="n">GANDiscriminativeLR</span><span class="p">,</span> <span class="n">mult_lr</span><span class="o">=</span><span class="mf">5.</span><span class="p">))</span>
</code></pre></div></div>

<p>Finally, we train the GAN for 40 epochs. We then switch to larger 192-pixel images with a batch size of 16 and train for another 10 epochs at half the learning rate.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lr</span> <span class="o">=</span> <span class="mf">1e-4</span>
<span class="n">learn</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="mi">40</span><span class="p">,</span> <span class="n">lr</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>3.718557</td>
      <td>3.852783</td>
      <td>03:27</td>
    </tr>
    <tr>
      <td>1</td>
      <td>3.262025</td>
      <td>3.452096</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>2</td>
      <td>3.241105</td>
      <td>3.499610</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>3</td>
      <td>3.098072</td>
      <td>3.511492</td>
      <td>03:31</td>
    </tr>
    <tr>
      <td>4</td>
      <td>3.161309</td>
      <td>3.211511</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>5</td>
      <td>3.108723</td>
      <td>2.590987</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>6</td>
      <td>3.049329</td>
      <td>3.215695</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>7</td>
      <td>3.156122</td>
      <td>3.255158</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>8</td>
      <td>3.039921</td>
      <td>3.255423</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>9</td>
      <td>3.136142</td>
      <td>3.109873</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>10</td>
      <td>2.969435</td>
      <td>3.096309</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>11</td>
      <td>2.967517</td>
      <td>3.532753</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>12</td>
      <td>3.066835</td>
      <td>3.302504</td>
      <td>03:28</td>
    </tr>
    <tr>
      <td>13</td>
      <td>2.979472</td>
      <td>3.147814</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>14</td>
      <td>2.848181</td>
      <td>3.229101</td>
      <td>03:29</td>
    </tr>
    <tr>
      <td>15</td>
      <td>2.981036</td>
      <td>3.370961</td>
      <td>03:30</td>
    </tr>
    <tr>
      <td>16</td>
      <td>2.874022</td>
      <td>3.646701</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>17</td>
      <td>2.816335</td>
      <td>3.517284</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>18</td>
      <td>2.886316</td>
      <td>3.336793</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>19</td>
      <td>2.851927</td>
      <td>3.596783</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>20</td>
      <td>2.885449</td>
      <td>3.560956</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>21</td>
      <td>3.081255</td>
      <td>3.357426</td>
      <td>03:31</td>
    </tr>
    <tr>
      <td>22</td>
      <td>2.812135</td>
      <td>3.340290</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>23</td>
      <td>2.933871</td>
      <td>3.475993</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>24</td>
      <td>3.084240</td>
      <td>3.034758</td>
      <td>03:31</td>
    </tr>
    <tr>
      <td>25</td>
      <td>2.983608</td>
      <td>3.113349</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>26</td>
      <td>2.746827</td>
      <td>2.865806</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>27</td>
      <td>2.789029</td>
      <td>3.173259</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>28</td>
      <td>2.952777</td>
      <td>3.227012</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>29</td>
      <td>2.825185</td>
      <td>3.053979</td>
      <td>03:34</td>
    </tr>
    <tr>
      <td>30</td>
      <td>2.782907</td>
      <td>3.444182</td>
      <td>03:34</td>
    </tr>
    <tr>
      <td>31</td>
      <td>2.805190</td>
      <td>3.343132</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>32</td>
      <td>2.901620</td>
      <td>3.299375</td>
      <td>03:33</td>
    </tr>
    <tr>
      <td>33</td>
      <td>2.744463</td>
      <td>3.279421</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>34</td>
      <td>2.818238</td>
      <td>3.048206</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>35</td>
      <td>2.755671</td>
      <td>2.975504</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>36</td>
      <td>2.764382</td>
      <td>3.075425</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>37</td>
      <td>2.714343</td>
      <td>3.076662</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>38</td>
      <td>2.805259</td>
      <td>3.291719</td>
      <td>03:32</td>
    </tr>
    <tr>
      <td>39</td>
      <td>2.787018</td>
      <td>3.172551</td>
      <td>03:32</td>
    </tr>
  </tbody>
</table>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">get_data</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">192</span><span class="p">)</span>
<span class="n">learn</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">lr</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>2.789968</td>
      <td>3.127500</td>
      <td>08:28</td>
    </tr>
    <tr>
      <td>1</td>
      <td>2.842687</td>
      <td>3.226334</td>
      <td>08:22</td>
    </tr>
    <tr>
      <td>2</td>
      <td>2.764777</td>
      <td>3.127393</td>
      <td>08:24</td>
    </tr>
    <tr>
      <td>3</td>
      <td>2.783910</td>
      <td>3.183345</td>
      <td>08:23</td>
    </tr>
    <tr>
      <td>4</td>
      <td>2.731649</td>
      <td>3.279976</td>
      <td>08:21</td>
    </tr>
    <tr>
      <td>5</td>
      <td>2.652934</td>
      <td>3.143363</td>
      <td>08:23</td>
    </tr>
    <tr>
      <td>6</td>
      <td>2.664248</td>
      <td>2.998718</td>
      <td>08:22</td>
    </tr>
    <tr>
      <td>7</td>
      <td>2.777635</td>
      <td>3.185632</td>
      <td>08:27</td>
    </tr>
    <tr>
      <td>8</td>
      <td>2.718668</td>
      <td>3.357025</td>
      <td>08:26</td>
    </tr>
    <tr>
      <td>9</td>
      <td>2.660009</td>
      <td>2.887908</td>
      <td>08:23</td>
    </tr>
  </tbody>
</table>

<p>The resulting images look like the following:</p>

<center>
<img src="/images/gan-produced-image.png" style="zoom: 70%;" /><br />
<figcaption>GAN Produced Images</figcaption>
</center>

<p>As you can see, our model was able to recolor the images reasonably well. This isn’t bad, but GANs do have weaknesses, which we’ll discuss in the last section. Before wrapping up the GAN section, let’s feed the model some external images, that is, images it hasn’t seen before.</p>

<h3 id="recoloring-external-images">Recoloring External Images</h3>

<p>The following pet images were taken randomly from the internet. I manually grayscaled them before letting the model predict its output.</p>

<center>
<img src="/images/gan-test-1.jpg" style="zoom: 70%;" /><br />
<figcaption>GAN Produced Images</figcaption>
</center>

<p>The colors produced, especially on the animals’ fur, are less saturated than in the original images. However, natural backgrounds like grass and the sky still look acceptable, although they differ from the originals.</p>

<p>Lastly, I fed the model images that contain neither cats nor dogs: photos of people. The top row is a picture that was already black-and-white when I received it. The bottom row’s image went through the same grayscaling process as the images above.</p>

<center>
<img src="/images/gan-test-2.jpg" style="zoom: 70%;" /><br />
<figcaption>GAN Produced Images</figcaption>
</center>

<p>A few things stand out in the first prediction. First, the model is biased towards green and yellow, hence the color of the floor in the first output. Second, aside from coloring the person in front, the model also colored the person on the phone’s screen.</p>

<p>The second prediction, on the other hand, did a great job coloring the backdrop of mountains and the sky. However, it struggled with the car, which should be bright red, and with the person, who remained mostly grey.</p>

<p>The most likely reason for the poor recoloring of people is the dataset the GAN was trained on, which in this case consists of pets.</p>

<h2 id="closing-remarks">Closing Remarks</h2>

<h3 id="weaknesses-of-gans">Weaknesses of GANs</h3>

<p>GANs are well known for being troublesome to handle, especially during training, hence the many configuration settings and knobs needed to make them behave well. Moreover, they take considerably longer to train than other architectures.</p>

<h3 id="possible-replacement-of-gans">Possible Replacement of GANs</h3>

<p>As shown in the rest of Lecture 7, there are other approaches that are as good as or even better than GANs. One of them is to use <strong>Feature Loss</strong> coupled with U-Nets, which trains in fewer hours and gives better results in several cases. I have tried that approach, but will not be discussing it here.</p>

<h3 id="conclusion">Conclusion</h3>

<p>GANs are great: the tasks they can do vary from one architecture to another, and they are one of the ways to let a model “dream” and have its own form of creativity. However, they have certain weaknesses, including long training times and the need for careful tweaking. They are a modern technique, and research in the domain is still very much open and fun to do if you’re into this particular field.</p>

<p>That’s it! Thanks for your time and I hope you’ve learned something!</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Generative Adversarial Network" /><summary type="html"><![CDATA[Fast.ai has a two-part Deep Learning Course, the first being Practical Deep Learning for Coders, and the second being Deep Learning from the Foundations, both having different approaches and intended for different audiences. In the 7th lecture of Part 1, Jeremy Howard taught a lot about modern architectures such as Residual Network (ResNet) , U-Net, and Generative Adversarial Network (GAN).]]></summary></entry><entry><title type="html">Handwritten Javanese Script Classification</title><link href="https://wilsonwongso.dev/posts/2020/07/handwritten-javanese-script-classifier/" rel="alternate" type="text/html" title="Handwritten Javanese Script Classification" /><published>2020-07-06T00:00:00+10:00</published><updated>2020-07-06T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/07/handwritten-javanese-script-classifier</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/07/handwritten-javanese-script-classifier/"><![CDATA[<p><em>Aksara Jawa</em>, or the <a href="https://en.wikipedia.org/wiki/Javanese_script">Javanese Script</a>, is the traditional script for writing the Javanese language and has influenced various other regional languages such as Sundanese, Madurese, etc. The script is now rarely used on a daily basis, but is sometimes taught in local schools in certain provinces of Indonesia.</p>

<h3 id="specific-form-of-aksara">Specific Form of Aksara</h3>

<p>The Javanese Script we will be classifying is specifically <a href="https://en.wikipedia.org/wiki/Javanese_script#Wyanjana">Aksara Wyanjana</a>’s <em>Nglegena</em>, i.e. its basic characters. The set consists of 20 basic characters, without their respective <em>Pasangan</em> characters.</p>

<h3 id="dataset">Dataset</h3>

<p>Since I could not find a handwritten Javanese Script dataset on the internet, I contacted one of my high school English teachers, who had once shown my class her ability to write Javanese Script. The characters were written on paper, scanned, and edited manually. Credits to <strong>Mm. Martha Indrati</strong> for the help!</p>

<h3 id="image-classification">Image Classification</h3>

<p>This project is heavily inspired by datasets like <a href="http://yann.lecun.com/exdb/mnist/">MNIST</a> and <a href="https://github.com/facebookresearch/qmnist">QMNIST</a>. Both consist of handwritten digits and are go-to datasets for learning image classification. The end goal of this project is to create a deep learning model that can classify handwritten Javanese Script characters with a reasonable degree of accuracy.</p>

<h2 id="code">Code</h2>

<p>The main framework we’ll use is fastai-v2, which sits on top of PyTorch. Fastai-v2 is still under development at the time of writing, but it is ready to be used for basic image classification tasks.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">fastai2.vision.all</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">import</span> <span class="nn">torch</span>
</code></pre></div></div>

<h3 id="load-data">Load Data</h3>

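<p>The data is grouped into one folder per class. The <code class="language-plaintext highlighter-rouge">GrandparentSplitter</code> below splits it into training (70%) and validation (30%) images, based on whether each class folder sits under a <code class="language-plaintext highlighter-rouge">train</code> or a <code class="language-plaintext highlighter-rouge">val</code> directory.</p>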

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="s">"handwritten-javanese-script-dataset"</span><span class="p">)</span>
</code></pre></div></div>
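<p>The <code class="language-plaintext highlighter-rouge">GrandparentSplitter(valid_name='val')</code> in the <code class="language-plaintext highlighter-rouge">DataBlock</code> below does not shuffle anything itself. It assigns each image by the name of its grandparent folder: <code class="language-plaintext highlighter-rouge">val</code> for validation and, by default, <code class="language-plaintext highlighter-rouge">train</code> for training. Below is a minimal sketch of how such a 70/30 split could be produced class by class. <code class="language-plaintext highlighter-rouge">split_per_class</code> and its seed are hypothetical and not part of the original project.</p>

```python
import random
from collections import defaultdict


def split_per_class(files, valid_pct=0.3, seed=42):
    """Split (class_name, filename) pairs into train/valid lists, per class.

    Splitting within each class keeps every character represented in both
    sets, which matters when each class only has a handful of images.
    """
    by_class = defaultdict(list)
    for label, fname in files:
        by_class[label].append(fname)

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, valid = [], []
    for label in sorted(by_class):
        names = sorted(by_class[label])
        rng.shuffle(names)
        n_valid = round(len(names) * valid_pct)
        valid += [(label, name) for name in names[:n_valid]]
        train += [(label, name) for name in names[n_valid:]]
    return train, valid


# Hypothetical file list: 20 classes x 10 images = 200 images.
files = [(f"class{c:02d}", f"img{i}.png") for c in range(20) for i in range(10)]
train, valid = split_per_class(files)
print(len(train), len(valid))  # 140 60
```

<p>With 10 images per class, this gives 140 training and 60 validation images. Every accuracy value in the training log below is a multiple of 1/60, which is consistent with a 60-image validation set.</p>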

<p>We’ll use a small batch size of 5 when creating the DataLoaders, mainly because we only have 200 images in total.</p>

<p>Most of the characters do not fully occupy the image. To handle this, we’ll apply <code class="language-plaintext highlighter-rouge">CropPad(90)</code> to crop each image around its center, then resize the result to 128px.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dblock</span> <span class="o">=</span> <span class="n">DataBlock</span><span class="p">(</span><span class="n">blocks</span>     <span class="o">=</span> <span class="p">(</span><span class="n">ImageBlock</span><span class="p">(</span><span class="n">cls</span><span class="o">=</span><span class="n">PILImageBW</span><span class="p">),</span> <span class="n">CategoryBlock</span><span class="p">),</span>
                   <span class="n">get_items</span>  <span class="o">=</span> <span class="n">get_image_files</span><span class="p">,</span>
                   <span class="n">splitter</span>   <span class="o">=</span> <span class="n">GrandparentSplitter</span><span class="p">(</span><span class="n">valid_name</span><span class="o">=</span><span class="s">'val'</span><span class="p">),</span>
                   <span class="n">get_y</span>      <span class="o">=</span> <span class="n">parent_label</span><span class="p">,</span>
                   <span class="n">item_tfms</span>  <span class="o">=</span> <span class="p">[</span><span class="n">CropPad</span><span class="p">(</span><span class="mi">90</span><span class="p">),</span> <span class="n">Resize</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="n">ResizeMethod</span><span class="p">.</span><span class="n">Crop</span><span class="p">)])</span>
</code></pre></div></div>
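<p>As I understand these transforms, <code class="language-plaintext highlighter-rouge">CropPad(90)</code> takes a 90x90 window from the center of each image, padding if the image is smaller. <code class="language-plaintext highlighter-rouge">Resize(128, method=ResizeMethod.Crop)</code> then scales that square up to 128x128. The pure-Python sketch below mimics those two steps on a 2D list of pixel values. The function names are hypothetical. For simplicity it uses nearest-neighbour sampling, whereas fastai interpolates more smoothly.</p>

```python
def center_crop_or_pad(img, size, fill=0):
    """Center-crop (or pad with `fill`) a 2D list of pixels to size x size."""
    h, w = len(img), len(img[0])
    # Top-left corner of the output window in input coordinates;
    # negative offsets mean the image is smaller and gets padded.
    top, left = (h - size) // 2, (w - size) // 2
    out = [[fill] * size for _ in range(size)]
    for y in range(size):
        for x in range(size):
            sy, sx = y + top, x + left
            if 0 <= sy < h and 0 <= sx < w:
                out[y][x] = img[sy][sx]
    return out


def resize_nearest(img, size):
    """Nearest-neighbour resize of a square 2D list to size x size."""
    n = len(img)
    return [[img[y * n // size][x * n // size] for x in range(size)]
            for y in range(size)]


# A hypothetical 100x120 grayscale scan (white background, one dark pixel).
scan = [[255] * 120 for _ in range(100)]
scan[50][60] = 0
cropped = center_crop_or_pad(scan, 90)   # 90x90 center window
resized = resize_nearest(cropped, 128)   # upscaled to 128x128
print(len(resized), len(resized[0]))     # 128 128
print(cropped[45][45])                   # 0, the dark pixel stays centered
```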

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span> <span class="o">=</span> <span class="n">dblock</span><span class="p">.</span><span class="n">dataloaders</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span><span class="p">.</span><span class="n">show_batch</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_9_0.png" style="zoom: 70%;" />
</center>

<p>There are 20 character classes in the set of Aksara we’ll be classifying.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dls</span><span class="p">.</span><span class="n">vocab</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(#20) ['ba','ca','da','dha','ga','ha','ja','ka','la','ma'...]
</code></pre></div></div>
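<p>The vocab is in alphabetical order because fastai sorts the category labels when building it. Class index <code class="language-plaintext highlighter-rouge">i</code> therefore always maps to <code class="language-plaintext highlighter-rouge">dls.vocab[i]</code>. The sketch below illustrates this with the 20 <em>Nglegena</em> characters in the traditional <em>hanacaraka</em> order. The output above confirms only the first ten names; the remaining ten are my assumption about how the other class folders are named.</p>

```python
# The 20 Nglegena characters in the traditional "hanacaraka" order.
# Assumption: the class folders use these romanized names.
HANACARAKA = [
    "ha", "na", "ca", "ra", "ka",
    "da", "ta", "sa", "wa", "la",
    "pa", "dha", "ja", "ya", "nya",
    "ma", "ga", "ba", "tha", "nga",
]

# Sorting the unique labels gives the alphabetical vocab order.
vocab = sorted(set(HANACARAKA))
label_to_idx = {label: i for i, label in enumerate(vocab)}

print(len(vocab))           # 20
print(vocab[:10])           # ['ba', 'ca', 'da', 'dha', 'ga', 'ha', 'ja', 'ka', 'la', 'ma']
print(label_to_idx["ma"])   # 9
```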

<h3 id="model">Model</h3>

<p>We’ll be using <strong>XResNet50</strong> as the model. It is based on the <a href="https://openaccess.thecvf.com/content_CVPR_2019/papers/He_Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks_CVPR_2019_paper.pdf">Bag of Tricks paper</a> and is an “extension” of the <a href="https://arxiv.org/abs/1512.03385">ResNet50</a> architecture. Since our images are grayscale, the model takes a single input channel (<code class="language-plaintext highlighter-rouge">c_in=1</code>) and outputs one score per class (<code class="language-plaintext highlighter-rouge">n_out=dls.c</code>). We’ll pass in our data, specify accuracy as the metric to observe, use <code class="language-plaintext highlighter-rouge">LabelSmoothingCrossEntropy</code> as the loss function, and add <code class="language-plaintext highlighter-rouge">MixUp</code> as a callback.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span> <span class="o">=</span> <span class="n">Learner</span><span class="p">(</span><span class="n">dls</span><span class="p">,</span> <span class="n">xresnet50</span><span class="p">(</span><span class="n">c_in</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_out</span><span class="o">=</span><span class="n">dls</span><span class="p">.</span><span class="n">c</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="n">accuracy</span><span class="p">,</span> <span class="n">loss_func</span><span class="o">=</span><span class="n">LabelSmoothingCrossEntropy</span><span class="p">(),</span> <span class="n">cbs</span><span class="o">=</span><span class="n">MixUp</span><span class="p">)</span>
</code></pre></div></div>
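<p>Both training tricks are easy to state in plain Python. The sketch below is a simplified, per-sample version, not fastai’s implementation. Label smoothing gives <code class="language-plaintext highlighter-rouge">1 - eps</code> of the weight to the usual negative log-likelihood of the target and spreads <code class="language-plaintext highlighter-rouge">eps</code> evenly over all classes (I assume <code class="language-plaintext highlighter-rouge">eps=0.1</code>). MixUp blends two images and their targets with a weight drawn from a Beta distribution (I assume <code class="language-plaintext highlighter-rouge">alpha=0.4</code>). fastai’s <code class="language-plaintext highlighter-rouge">MixUp</code> callback interpolates the two losses rather than the one-hot targets, which gives the same result for cross-entropy-style losses like this one.</p>

```python
import math
import random


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def label_smoothing_ce(logits, target, eps=0.1):
    """Label-smoothing cross-entropy for a single sample.

    (1 - eps) of the weight goes to the negative log-likelihood of the
    target class; eps is spread uniformly over all classes.
    """
    log_probs = log_softmax(logits)
    uniform_term = -sum(log_probs) / len(log_probs)
    nll = -log_probs[target]
    return eps * uniform_term + (1 - eps) * nll


def mixup(x1, y1, x2, y2, alpha=0.4, rng=random):
    """Blend two samples (flat pixel lists and one-hot targets) MixUp-style."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1 - lam)  # keep the first sample dominant
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


# Identical logits for all 20 classes: the model is maximally unsure.
print(round(label_smoothing_ce([0.0] * 20, target=3), 4))  # 2.9957

x, y, lam = mixup([1.0, 0.0], [1, 0], [0.0, 1.0], [0, 1], rng=random.Random(0))
print(lam >= 0.5, round(sum(y), 9))  # True 1.0
```

<p>With identical logits for all 20 classes, the loss is exactly log(20) ≈ 3.0. This is roughly where the validation loss starts in the training log below (3.11 at epoch 0). Label smoothing also means the loss never reaches zero, since some probability is always expected on the non-target classes.</p>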

<h3 id="training-model">Training Model</h3>

<p>With everything in place, let’s train the model to learn from the dataset and predict which class each image belongs to. First, we use the learning rate finder to choose a learning rate.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">lr_find</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>SuggestedLRs(lr_min=0.0003019951749593019, lr_steep=6.309573450380412e-07)
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_15_2.png" style="zoom: 70%;" />
</center>
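<p><code class="language-plaintext highlighter-rouge">lr_find</code> reported two suggestions. As far as I know, fastai2’s heuristics at the time were roughly these. <code class="language-plaintext highlighter-rouge">lr_min</code> is the learning rate at the lowest loss, divided by 10. <code class="language-plaintext highlighter-rouge">lr_steep</code> is the learning rate where the loss drops fastest with respect to log(learning rate). The sketch below reimplements that idea on plain lists; fastai also trims the noisy ends of the sweep, which is omitted here. The training run below uses 3e-4, essentially the suggested <code class="language-plaintext highlighter-rouge">lr_min</code> of about 3.0e-4.</p>

```python
import math


def suggest_lrs(lrs, losses):
    """Return (lr_min, lr_steep) from a learning-rate sweep.

    lr_min:   learning rate at the lowest loss, divided by 10.
    lr_steep: learning rate where the loss falls fastest per unit of log(lr).
    """
    i_min = min(range(len(losses)), key=losses.__getitem__)
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i_steep = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[i_min] / 10, lrs[i_steep]


# A made-up sweep: the loss bottoms out at 1e-2 and falls fastest right after 1e-5.
lrs = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [3.0, 2.9, 2.3, 2.1, 2.0, 5.0]
lr_min, lr_steep = suggest_lrs(lrs, losses)
print(f"lr_min={lr_min:.0e}, lr_steep={lr_steep:.0e}")  # lr_min=1e-03, lr_steep=1e-05
```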

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">fit_one_cycle</span><span class="p">(</span><span class="mi">30</span><span class="p">,</span> <span class="mf">3e-4</span><span class="p">,</span> <span class="n">cbs</span><span class="o">=</span><span class="n">SaveModelCallback</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s">'accuracy'</span><span class="p">,</span> <span class="n">fname</span><span class="o">=</span><span class="s">'best_model'</span><span class="p">),</span> <span class="n">wd</span><span class="o">=</span><span class="mf">0.4</span><span class="p">)</span>
</code></pre></div></div>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: left;">
      <th>epoch</th>
      <th>train_loss</th>
      <th>valid_loss</th>
      <th>accuracy</th>
      <th>time</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0</td>
      <td>3.067268</td>
      <td>3.108827</td>
      <td>0.050000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>1</td>
      <td>2.929908</td>
      <td>2.669373</td>
      <td>0.333333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>2</td>
      <td>2.769148</td>
      <td>2.293764</td>
      <td>0.383333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>3</td>
      <td>2.588481</td>
      <td>2.215439</td>
      <td>0.316667</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>4</td>
      <td>2.416248</td>
      <td>2.324036</td>
      <td>0.283333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>5</td>
      <td>2.324458</td>
      <td>1.983255</td>
      <td>0.533333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>6</td>
      <td>2.189000</td>
      <td>2.105889</td>
      <td>0.383333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>7</td>
      <td>2.078479</td>
      <td>2.350886</td>
      <td>0.333333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>8</td>
      <td>1.922369</td>
      <td>2.823610</td>
      <td>0.216667</td>
      <td>00:05</td>
    </tr>
    <tr>
      <td>9</td>
      <td>1.790820</td>
      <td>1.584189</td>
      <td>0.650000</td>
      <td>00:05</td>
    </tr>
    <tr>
      <td>10</td>
      <td>1.683853</td>
      <td>1.509675</td>
      <td>0.583333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>11</td>
      <td>1.598790</td>
      <td>1.570487</td>
      <td>0.650000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>12</td>
      <td>1.528586</td>
      <td>1.256149</td>
      <td>0.833333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>13</td>
      <td>1.484508</td>
      <td>1.623523</td>
      <td>0.566667</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>14</td>
      <td>1.437240</td>
      <td>1.340925</td>
      <td>0.750000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>15</td>
      <td>1.345987</td>
      <td>1.138785</td>
      <td>0.816667</td>
      <td>00:05</td>
    </tr>
    <tr>
      <td>16</td>
      <td>1.350891</td>
      <td>1.370259</td>
      <td>0.716667</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>17</td>
      <td>1.297572</td>
      <td>1.453033</td>
      <td>0.666667</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>18</td>
      <td>1.318248</td>
      <td>1.330522</td>
      <td>0.750000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>19</td>
      <td>1.263931</td>
      <td>1.023822</td>
      <td>0.900000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>20</td>
      <td>1.247242</td>
      <td>1.063768</td>
      <td>0.900000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>21</td>
      <td>1.234829</td>
      <td>1.009032</td>
      <td>0.933333</td>
      <td>00:05</td>
    </tr>
    <tr>
      <td>22</td>
      <td>1.203268</td>
      <td>0.968369</td>
      <td>0.950000</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>23</td>
      <td>1.178766</td>
      <td>0.965601</td>
      <td>0.916667</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>24</td>
      <td>1.156069</td>
      <td>0.939599</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>25</td>
      <td>1.183693</td>
      <td>0.943586</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>26</td>
      <td>1.166053</td>
      <td>0.933629</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>27</td>
      <td>1.162939</td>
      <td>0.936014</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>28</td>
      <td>1.132883</td>
      <td>0.936722</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
    <tr>
      <td>29</td>
      <td>1.138776</td>
      <td>0.946842</td>
      <td>0.933333</td>
      <td>00:04</td>
    </tr>
  </tbody>
</table>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Better model found at epoch 0 with accuracy value: 0.05000000074505806.
Better model found at epoch 1 with accuracy value: 0.3333333432674408.
Better model found at epoch 2 with accuracy value: 0.38333332538604736.
Better model found at epoch 5 with accuracy value: 0.5333333611488342.
Better model found at epoch 9 with accuracy value: 0.6499999761581421.
Better model found at epoch 12 with accuracy value: 0.8333333134651184.
Better model found at epoch 19 with accuracy value: 0.8999999761581421.
Better model found at epoch 21 with accuracy value: 0.9333333373069763.
Better model found at epoch 22 with accuracy value: 0.949999988079071.
</code></pre></div></div>
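<p>The <code class="language-plaintext highlighter-rouge">fit_one_cycle</code> call trains for 30 epochs under the one-cycle policy. Meanwhile, <code class="language-plaintext highlighter-rouge">SaveModelCallback(monitor='accuracy')</code> saves a checkpoint every time validation accuracy improves. The sketch below recreates both pieces in plain Python.</p>
<p>The schedule assumes what I believe are fastai2’s defaults:</p>
<ul>
  <li>a cosine warm-up from <code class="language-plaintext highlighter-rouge">lr_max/25</code> to <code class="language-plaintext highlighter-rouge">lr_max</code> over the first 25% of steps;</li>
  <li>then cosine annealing down to <code class="language-plaintext highlighter-rouge">lr_max/1e5</code>.</li>
</ul>
<p>It ignores momentum and weight decay. Feeding <code class="language-plaintext highlighter-rouge">improved_epochs</code> the accuracy column of the table above reproduces the “Better model found” log exactly.</p>

```python
import math


def cos_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(step, total_steps, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at `step` of a one-cycle schedule (assumed fastai2 defaults)."""
    pos = step / total_steps
    if pos < pct_start:
        # Warm-up phase: lr_max/div -> lr_max.
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max/div_final.
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


def improved_epochs(metric_history):
    """Epochs at which a higher-is-better metric strictly beat its previous best."""
    best, epochs = float("-inf"), []
    for epoch, value in enumerate(metric_history):
        if value > best:
            best = value
            epochs.append(epoch)
    return epochs


# Validation accuracy per epoch, copied from the training table above.
accuracy = [0.050000, 0.333333, 0.383333, 0.316667, 0.283333, 0.533333,
            0.383333, 0.333333, 0.216667, 0.650000, 0.583333, 0.650000,
            0.833333, 0.566667, 0.750000, 0.816667, 0.716667, 0.666667,
            0.750000, 0.900000, 0.900000, 0.933333, 0.950000, 0.916667,
            0.933333, 0.933333, 0.933333, 0.933333, 0.933333, 0.933333]
print(improved_epochs(accuracy))  # [0, 1, 2, 5, 9, 12, 19, 21, 22]

# Peak learning rate is reached a quarter of the way through training.
print(one_cycle_lr(25, 100, 3e-4))  # 0.0003
```

<p>Epochs 11 and 20 only tie the previous best (0.65 and 0.90). The comparison is strict, so no new checkpoint is written for them.</p>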

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">recorder</span><span class="p">.</span><span class="n">plot_loss</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_17_0.png" style="zoom: 70%;" />
</center>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="s">'stage-1'</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="analyze-results">Analyze Results</h3>

<p>After training, let’s see how well our model learned. Any incorrect prediction in a random batch will have its label colored red.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">show_results</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_20_1.png" style="zoom: 70%;" />
</center>

<p>Instead of only viewing a batch, let’s analyze the results from the entire validation dataset.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">interp</span> <span class="o">=</span>  <span class="n">ClassificationInterpretation</span><span class="p">.</span><span class="n">from_learner</span><span class="p">(</span><span class="n">learn</span><span class="p">)</span>
</code></pre></div></div>

<p>This confusion matrix counts every actual label against every predicted label. Correct predictions fall on the diagonal, so the darker the blue on the diagonal, the better our model is at predicting.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">interp</span><span class="p">.</span><span class="n">plot_confusion_matrix</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">60</span><span class="p">)</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_24_0.png" style="zoom: 70%;" />
</center>
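<p>Under the hood, a confusion matrix is just a table of counts. The minimal sketch below builds one from lists of actual and predicted labels, with rows for actual classes and columns for predicted classes, as in the plot above. The three labels and the predictions are made-up values for illustration. On a real <code class="language-plaintext highlighter-rouge">ClassificationInterpretation</code>, <code class="language-plaintext highlighter-rouge">interp.most_confused()</code> lists the largest off-diagonal cells directly.</p>

```python
def confusion_matrix(actual, predicted, labels):
    """Count how often each actual label was predicted as each label.

    Rows are actual classes and columns are predicted classes, so correct
    predictions accumulate on the diagonal.
    """
    index = {label: i for i, label in enumerate(labels)}
    n = len(labels)
    matrix = [[0] * n for _ in range(n)]
    for a, p in zip(actual, predicted):
        matrix[index[a]][index[p]] += 1
    return matrix


def accuracy_from_matrix(matrix):
    """Fraction of samples that lie on the diagonal."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


labels = ["ca", "ma", "wa"]                      # hypothetical subset of classes
actual = ["ca", "ca", "ma", "wa", "wa"]
predicted = ["ca", "ca", "ma", "ca", "wa"]
cm = confusion_matrix(actual, predicted, labels)
print(cm)                        # [[2, 0, 0], [0, 1, 0], [1, 0, 1]]
print(accuracy_from_matrix(cm))  # 0.8
```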

<p><code class="language-plaintext highlighter-rouge">plot_top_losses</code>, on the other hand, shows the validation images with the highest loss. For each one, it shows what our model predicted and how confident it was in that prediction.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">interp</span><span class="p">.</span><span class="n">plot_top_losses</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span><span class="mi">9</span><span class="p">))</span>
</code></pre></div></div>

<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_26_0.png" style="zoom: 70%;" />
</center>
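<p><code class="language-plaintext highlighter-rouge">plot_top_losses</code> ranks validation images by their loss and shows the worst ones first. The sketch below does the same ranking on plain lists of predicted probabilities. <code class="language-plaintext highlighter-rouge">rank_by_loss</code> is a hypothetical helper that scores with ordinary cross-entropy, whereas fastai uses the learner’s own loss function (label-smoothing cross-entropy in this project). The probabilities are made up.</p>

```python
import math


def rank_by_loss(probs, targets, k=3):
    """Return the k (index, loss, prob_of_target) triples with the highest loss.

    Loss is plain cross-entropy: -log(probability of the actual class).
    """
    scored = []
    for i, (p, t) in enumerate(zip(probs, targets)):
        prob = max(p[t], 1e-12)  # guard against log(0)
        scored.append((i, -math.log(prob), prob))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


probs = [
    [0.90, 0.05, 0.05],  # confident and correct
    [0.20, 0.70, 0.10],  # confident but wrong
    [0.40, 0.35, 0.25],  # unsure, barely correct
]
targets = [0, 0, 0]
for idx, loss, prob in rank_by_loss(probs, targets, k=2):
    print(idx, round(loss, 3), prob)
# 1 1.609 0.2
# 2 0.916 0.4
```

<p>A confidently wrong prediction ranks above an unsure but correct one, which is why the top-loss plot is a quick way to find the most misleading images.</p>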

<h3 id="predicting-external-images">Predicting External Images</h3>

<p>To see how well our model generalizes, let’s feed it some external data and see what it predicts.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">open_image_bw_resize</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">PILImageBW</span><span class="p">:</span>
    <span class="k">return</span> <span class="n">PILImageBW</span><span class="p">(</span><span class="n">Image</span><span class="p">.</span><span class="nb">open</span><span class="p">(</span><span class="n">source</span><span class="p">).</span><span class="n">resize</span><span class="p">((</span><span class="mi">128</span><span class="p">,</span><span class="mi">128</span><span class="p">)).</span><span class="n">convert</span><span class="p">(</span><span class="s">'L'</span><span class="p">))</span>
</code></pre></div></div>

<p>The following character is supposed to be <strong>ma</strong> and was picked randomly from available images on the internet.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test0</span> <span class="o">=</span> <span class="n">open_image_bw_resize</span><span class="p">(</span><span class="s">'test-image-0.jpg'</span><span class="p">)</span>
<span class="n">test0</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>


<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_31_1.png" style="zoom: 70%;" />
</center>

<p>Feed it through the model and see its output.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'ma'
</code></pre></div></div>

<p>Luckily, the model predicted the character correctly. To challenge the model even more, I wrote some Javanese Script characters myself to see what the model predicts. Do note that I have no background in writing Javanese Script, so pardon my handwriting.</p>

<p>The following character is supposed to be <strong>ca</strong>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test1</span> <span class="o">=</span> <span class="n">open_image_bw_resize</span><span class="p">(</span><span class="s">'test-image-1.jpg'</span><span class="p">)</span>
<span class="n">test1</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>


<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_35_1.png" style="zoom: 70%;" />
</center>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'ca'
</code></pre></div></div>

<p>This character is supposed to be <strong>wa</strong>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">test2</span> <span class="o">=</span> <span class="n">open_image_bw_resize</span><span class="p">(</span><span class="s">'test-image-2.jpg'</span><span class="p">)</span>
<span class="n">test2</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>


<center>
<img src="/images/2020/07/handwritten-javanese-script-classifier/output_38_1.png" style="zoom: 70%;" />
</center>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">learn</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test2</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>'ca'
</code></pre></div></div>

<p>Well, that’s an incorrect guess, which is understandable for two reasons. First, my handwriting is poor. Second, the model was trained on one person’s particular style of handwriting, in this case my teacher’s. Other factors could also have contributed, such as the model overfitting and the small dataset.</p>
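<p>When the model gets a character wrong, it helps to look beyond the top label. In fastai2, <code class="language-plaintext highlighter-rouge">learn.predict</code> returns a tuple of the decoded label, the class index, and the per-class probabilities; the <code class="language-plaintext highlighter-rouge">[0]</code> above keeps only the label. The hypothetical helpers below rank those probabilities against the vocab and measure the gap between the top two classes. The probabilities in the example are made up. In practice, you could pass something like <code class="language-plaintext highlighter-rouge">learn.predict(test2)[2].tolist()</code> together with <code class="language-plaintext highlighter-rouge">dls.vocab</code>.</p>

```python
def top_k_predictions(probs, vocab, k=3):
    """Pair each class label with its probability and return the k most likely."""
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def top2_margin(probs):
    """Gap between the two highest probabilities; a small gap suggests a near miss."""
    first, second = sorted(probs, reverse=True)[:2]
    return first - second


vocab = ["ca", "ma", "wa"]   # hypothetical 3-class vocab
probs = [0.55, 0.05, 0.40]   # made-up probabilities
print(top_k_predictions(probs, vocab, k=2))  # [('ca', 0.55), ('wa', 0.4)]
print(round(top2_margin(probs), 2))          # 0.15
```

<p>A small gap between the first and second guesses would suggest a near miss, which would be worth knowing for the <strong>wa</strong> character.</p>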

<h2 id="closing-remarks">Closing Remarks</h2>

<p>Several improvements could be made. One is to increase the variety and size of the dataset, since the model is trained on a single person’s handwriting. Adding other people’s handwriting to the mix should help the model generalize better.</p>

<p>That’s it for this mini project of mine. Thanks for your time and I hope you’ve learned something!</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Convolutional Neural Network" /><summary type="html"><![CDATA[Aksara Jawa, or the Javanese Script is the core of writing the Javanese language and has influenced various other regional languages such as Sundanese, Madurese, etc. The script is now rarely used on a daily basis, but is sometimes taught in local schools in certain provinces of Indonesia.]]></summary></entry><entry><title type="html">Hash Tables, Collisions, and Separate Chaining</title><link href="https://wilsonwongso.dev/posts/2020/04/doubly-linked-list-c/" rel="alternate" type="text/html" title="Hash Tables, Collisions, and Separate Chaining" /><published>2020-04-11T00:00:00+10:00</published><updated>2020-04-11T00:00:00+10:00</updated><id>https://wilsonwongso.dev/posts/2020/04/hash-table-c</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/04/doubly-linked-list-c/"><![CDATA[<p>According to <a href="https://en.wikipedia.org/wiki/Hash_table">Wikipedia</a>, a <strong>hash table</strong>, sometimes called a <strong>hash map</strong>, is a data structure that implements an associative array abstract data type, a structure that can map keys to values.</p>

<p>Unlike Linked Lists, Hash Tables allow for O(1) average-case time complexity when searching, a powerful improvement over the O(n) search of a Linked List.
However, collisions can occur when inserting new data whose key hashes to an index that is already in use. In the worst case, when every key lands on the same index, searching degrades to O(n), just like a Linked List.</p>

<p>We’ll implement a Hash Table using the C language.
Each key of the hash table will hold the head <code class="language-plaintext highlighter-rouge">node</code> of a Singly Linked List, so that colliding entries can be stored through separate chaining.</p>

<p>The complete code for this post can be found <a href="https://github.com/w11wo/wilsonwongso.dev/blob/master/files/code/hashTable.c">here</a>.</p>

<h2 id="header-files">Header Files</h2>

<p>The only header files we’ll be using are the following:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;stdio.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;stdlib.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;string.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;ctype.h&gt;</span><span class="cp">
</span></code></pre></div></div>

<h2 id="hash-table-size">Hash Table Size</h2>

<p>To keep the size of the hash table constant, we will define the maximum number of keys using the <code class="language-plaintext highlighter-rouge">#define</code> preprocessor directive. There will be only 26 keys, one for each letter of the alphabet.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#define MAX_N 26
</span></code></pre></div></div>

<h2 id="node-struct">Node Struct</h2>

<p>Each node in the linked list consists only of a <code class="language-plaintext highlighter-rouge">name</code> <code class="language-plaintext highlighter-rouge">string</code> and a <code class="language-plaintext highlighter-rouge">pointer</code> <code class="language-plaintext highlighter-rouge">next</code>, which points to the next node in the linked list.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">typedef</span> <span class="k">struct</span> <span class="n">node</span> <span class="p">{</span>
    <span class="kt">char</span> <span class="n">name</span><span class="p">[</span><span class="mi">200</span><span class="p">];</span>
    <span class="k">struct</span> <span class="n">node</span><span class="o">*</span> <span class="n">next</span><span class="p">;</span>
<span class="p">}</span> <span class="n">node</span><span class="p">;</span>
</code></pre></div></div>

<h2 id="function-prototypes">Function Prototypes</h2>

<p>Since we are writing in C and will implement our functions below the <code class="language-plaintext highlighter-rouge">main</code> function, we first need to declare a prototype for each of them. Here is the list of functions we’ll be implementing:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">node</span><span class="o">*</span> <span class="nf">create_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">);</span>
<span class="kt">int</span> <span class="nf">hash</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">insert</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[],</span> <span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">);</span>
<span class="kt">char</span><span class="o">*</span> <span class="nf">search</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[],</span> <span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">print_list</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">head</span><span class="p">,</span> <span class="kt">int</span> <span class="n">idx</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">print_table</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[]);</span>
</code></pre></div></div>

<h2 id="creating-a-node">Creating a Node</h2>

<p>To create our <code class="language-plaintext highlighter-rouge">node</code>, we will implement the following function. It is almost identical to the one in the Singly Linked List <a href="https://wilsonwongso.dev/blog/code/c/2020/02/28/singly-linked-list-c.html">blog post</a>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">node</span><span class="o">*</span> <span class="nf">create_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">node</span><span class="o">*</span> <span class="n">student</span> <span class="o">=</span> <span class="p">(</span><span class="n">node</span><span class="o">*</span><span class="p">)</span> <span class="n">malloc</span><span class="p">(</span><span class="k">sizeof</span><span class="p">(</span><span class="n">node</span><span class="p">));</span>

    <span class="n">strcpy</span><span class="p">(</span><span class="n">student</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">name</span><span class="p">);</span>
    <span class="n">student</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>

    <span class="k">return</span> <span class="n">student</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="hashing-function">Hashing Function</h2>

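<p>The hashing function we’ll be using is a very simple one.
Specifically, we take the first letter of the string, convert it to lowercase, and subtract the ASCII value of <code class="language-plaintext highlighter-rouge">'a'</code> from it, so names starting with “a” get the key 0, “b” gets 1, and so on up to 25 for “z”. This assumes every name starts with a letter; any other first character (or an empty string) would produce a key outside that range.</p>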

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="nf">hash</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">return</span> <span class="n">tolower</span><span class="p">((</span><span class="kt">unsigned</span> <span class="kt">char</span><span class="p">)</span> <span class="n">name</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">-</span> <span class="sc">'a'</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>

<p>With that, a string can only map to 26 possible keys, which matches the <code class="language-plaintext highlighter-rouge">MAX_N</code> we defined earlier.
There are various ways to create a better hashing function that reduces collisions, which we’ll discuss later in this post.</p>
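<p>As a quick sanity check, here is a minimal sketch of the first-letter hash applied to a few names, assuming <code class="language-plaintext highlighter-rouge">MAX_N</code> is 26 as above. The helper <code class="language-plaintext highlighter-rouge">is_valid_key()</code> is hypothetical and not part of the post’s code; it flags names whose first character is not a letter, since their key would fall outside the table.</p>

```c
#include <ctype.h>
#include <stdbool.h>

#define MAX_N 26 /* assumed: one bucket per letter, as in this post */

/* First-letter hash: 'a'/'A' -> 0, 'b'/'B' -> 1, ..., 'z'/'Z' -> 25. */
int hash(const char* name) {
    return tolower((unsigned char) name[0]) - 'a';
}

/* Hypothetical helper: true only if hash(name) is a valid index
   into a table of MAX_N buckets, i.e. the name starts with a letter. */
bool is_valid_key(const char* name) {
    int key = hash(name);
    return key >= 0 && key < MAX_N;
}
```

<p>Calling <code class="language-plaintext highlighter-rouge">is_valid_key(name)</code> before <code class="language-plaintext highlighter-rouge">insert(root, name)</code> would be one way to avoid writing outside <code class="language-plaintext highlighter-rouge">root</code>.</p>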

<h2 id="inserting-a-node-into-the-hash-table">Inserting a Node into the Hash Table</h2>

<p>First, we use <code class="language-plaintext highlighter-rouge">create_node()</code> to allocate the required memory and create the <code class="language-plaintext highlighter-rouge">student</code> <code class="language-plaintext highlighter-rouge">node</code>.
Then, using the hashing function, we get the corresponding key for the given <code class="language-plaintext highlighter-rouge">name</code>.</p>

<p>Inserting the first node of a key is as simple as taking the address of the corresponding <code class="language-plaintext highlighter-rouge">head</code> pointer and pointing it at the newly created <code class="language-plaintext highlighter-rouge">student</code>.</p>

<p>However, we also need to handle collisions, which happen when two different names map to the same key. We do this with a method called separate chaining: each <code class="language-plaintext highlighter-rouge">node</code> whose key is already taken is appended to the end of that key’s linked list.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">insert</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[],</span> <span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">node</span><span class="o">*</span> <span class="n">student</span> <span class="o">=</span> <span class="n">create_node</span><span class="p">(</span><span class="n">name</span><span class="p">);</span>

    <span class="kt">int</span> <span class="n">key</span> <span class="o">=</span> <span class="n">hash</span><span class="p">(</span><span class="n">name</span><span class="p">);</span>

    <span class="n">node</span><span class="o">**</span> <span class="n">head</span> <span class="o">=</span> <span class="o">&amp;</span><span class="n">root</span><span class="p">[</span><span class="n">key</span><span class="p">];</span>

    <span class="k">if</span> <span class="p">(</span><span class="o">*</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if the head of a particular key is still NULL;</span>
        <span class="o">*</span><span class="n">head</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span> <span class="c1">// separate chaining, i.e. push the new node to the back of the linked list.</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="o">*</span><span class="n">head</span><span class="p">;</span>
        <span class="k">while</span><span class="p">(</span><span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
        <span class="p">}</span>
        <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
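<p>Appending to the back means walking the whole chain on every insertion. A common alternative is to push the new node onto the front of the chain, which takes constant time if the order of names within a key does not matter. The sketch below assumes a <code class="language-plaintext highlighter-rouge">node</code> with a fixed-size <code class="language-plaintext highlighter-rouge">name</code> buffer; <code class="language-plaintext highlighter-rouge">insert_front()</code> is a hypothetical name, not the post’s function.</p>

```c
#include <ctype.h>
#include <stdlib.h>
#include <string.h>

#define MAX_N 26

/* Assumed node layout for this sketch: fixed-size name plus next pointer. */
typedef struct node {
    char name[100];
    struct node* next;
} node;

int hash(const char* name) {
    return tolower((unsigned char) name[0]) - 'a';
}

/* Hypothetical alternative to insert(): push onto the front of the chain. */
void insert_front(node* root[], const char* name) {
    node* student = malloc(sizeof(node));
    if (student == NULL) {
        return; /* out of memory: leave the table unchanged */
    }
    /* bounded copy so long names cannot overflow the buffer */
    strncpy(student->name, name, sizeof(student->name) - 1);
    student->name[sizeof(student->name) - 1] = '\0';

    int key = hash(name);
    student->next = root[key]; /* old head (possibly NULL) becomes second */
    root[key] = student;       /* new node becomes the head */
}
```

<p>With this version, bucket 0 would list <code class="language-plaintext highlighter-rouge">Avocado</code> before <code class="language-plaintext highlighter-rouge">Apple</code>, since the most recently inserted name becomes the head.</p>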

<h2 id="searching-for-a-name-inside-the-hash-table">Searching for a Name inside the Hash Table</h2>

<p>Fast lookup is one of the most powerful features of a hash table, as discussed previously.
Within a key, however, the names are stored in a linked list, so to search by <code class="language-plaintext highlighter-rouge">name</code> we use linear search through that list.</p>

<p>Just like insertion, we handle both possible scenarios: a <code class="language-plaintext highlighter-rouge">NULL</code> <code class="language-plaintext highlighter-rouge">head</code>, meaning no name with that key exists, and otherwise a linked list that we traverse node by node.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">char</span><span class="o">*</span> <span class="nf">search</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[],</span> <span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">key</span> <span class="o">=</span> <span class="n">hash</span><span class="p">(</span><span class="n">name</span><span class="p">);</span>

    <span class="n">node</span><span class="o">*</span> <span class="n">head</span> <span class="o">=</span> <span class="n">root</span><span class="p">[</span><span class="n">key</span><span class="p">];</span>

    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">return</span> <span class="nb">NULL</span><span class="p">;</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
        <span class="k">while</span><span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">if</span> <span class="p">(</span><span class="n">strcmp</span><span class="p">(</span><span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
                <span class="k">return</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">;</span>
            <span class="p">}</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
        <span class="p">}</span>
        <span class="k">return</span> <span class="nb">NULL</span><span class="p">;</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
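<p>To make the cost of collisions concrete, here is a sketch of a hypothetical <code class="language-plaintext highlighter-rouge">search_steps()</code> variant that runs the same linear search over one chain and also reports how many names it compared. It assumes a <code class="language-plaintext highlighter-rouge">node</code> with a fixed-size <code class="language-plaintext highlighter-rouge">name</code> buffer.</p>

```c
#include <stddef.h>
#include <string.h>

/* Assumed node layout for this sketch. */
typedef struct node {
    char name[100];
    struct node* next;
} node;

/* Hypothetical variant of search(): linear search through one chain,
   storing the number of string comparisons made in *steps. */
const char* search_steps(node* head, const char* name, int* steps) {
    *steps = 0;
    for (node* curr = head; curr != NULL; curr = curr->next) {
        ++*steps;
        if (strcmp(curr->name, name) == 0) {
            return curr->name;
        }
    }
    return NULL; /* empty chain or name not present */
}
```

<p>In the table built in <code class="language-plaintext highlighter-rouge">main</code> below, for instance, finding <code class="language-plaintext highlighter-rouge">Plum</code> means comparing against <code class="language-plaintext highlighter-rouge">Papaya</code> and <code class="language-plaintext highlighter-rouge">Peach</code> first, three comparisons in total, whereas <code class="language-plaintext highlighter-rouge">Blueberry</code> is found after one.</p>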

<h2 id="printing-a-linked-list">Printing a Linked List</h2>

<p>Since each key in the hash table holds a linked list, we need a function that prints a linked list.
We would also like to print the index of the linked list in the hash table, so we pass another parameter called <code class="language-plaintext highlighter-rouge">idx</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">print_list</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">head</span><span class="p">,</span> <span class="kt">int</span> <span class="n">idx</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">return</span><span class="p">;</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"[%d] "</span><span class="p">,</span> <span class="n">idx</span><span class="p">);</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
        <span class="k">while</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">printf</span><span class="p">(</span><span class="s">"%s"</span><span class="p">,</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">);</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
            <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
                <span class="n">printf</span><span class="p">(</span><span class="s">" -&gt; "</span><span class="p">);</span>
            <span class="p">}</span>
        <span class="p">}</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">);</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="printing-the-hash-table">Printing the Hash Table</h2>

<p>Finally, we can implement a function to print every non-empty linked list and its contents.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">print_table</span><span class="p">(</span><span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[])</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">MAX_N</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">print_list</span><span class="p">(</span><span class="n">root</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">i</span><span class="p">);</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="main-function">Main Function</h2>

<p>Here is what the <code class="language-plaintext highlighter-rouge">main</code> function looks like.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="nf">main</span><span class="p">(</span><span class="kt">void</span><span class="p">)</span> <span class="p">{</span>

    <span class="n">node</span><span class="o">*</span> <span class="n">root</span><span class="p">[</span><span class="n">MAX_N</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="nb">NULL</span><span class="p">};</span>

    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Apple"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Orange"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Papaya"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Avocado"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Blueberry"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Peach"</span><span class="p">);</span>
    <span class="n">insert</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Plum"</span><span class="p">);</span>

    <span class="kt">char</span><span class="o">*</span> <span class="n">find_banana</span> <span class="o">=</span> <span class="n">search</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Banana"</span><span class="p">);</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">find_banana</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"%s found</span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">find_banana</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="kt">char</span><span class="o">*</span> <span class="n">find_avocado</span> <span class="o">=</span> <span class="n">search</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="s">"Avocado"</span><span class="p">);</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">find_avocado</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"%s found</span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">find_avocado</span><span class="p">);</span>
    <span class="p">}</span>

    <span class="n">print_table</span><span class="p">(</span><span class="n">root</span><span class="p">);</span>

    <span class="k">return</span> <span class="mi">0</span><span class="p">;</span>

<span class="p">}</span>
</code></pre></div></div>

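<p>Since <code class="language-plaintext highlighter-rouge">Banana</code> was never inserted, only the search for <code class="language-plaintext highlighter-rouge">Avocado</code> prints a result. The output looks like this:</p>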

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Avocado found
[0] Apple -&gt; Avocado
[1] Blueberry
[14] Orange
[15] Papaya -&gt; Peach -&gt; Plum
</code></pre></div></div>
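<p>The <code class="language-plaintext highlighter-rouge">main</code> function above never frees the nodes allocated by <code class="language-plaintext highlighter-rouge">create_node()</code>. Here is a minimal sketch of a hypothetical <code class="language-plaintext highlighter-rouge">free_table()</code> helper that could be called just before <code class="language-plaintext highlighter-rouge">return 0;</code>, assuming the same <code class="language-plaintext highlighter-rouge">node</code> layout with a <code class="language-plaintext highlighter-rouge">next</code> pointer and a <code class="language-plaintext highlighter-rouge">MAX_N</code> of 26:</p>

```c
#include <stdlib.h>

#define MAX_N 26

/* Assumed node layout for this sketch. */
typedef struct node {
    char name[100];
    struct node* next;
} node;

/* Hypothetical cleanup helper: free every chain and reset each bucket. */
void free_table(node* root[]) {
    for (int i = 0; i < MAX_N; ++i) {
        node* curr = root[i];
        while (curr != NULL) {
            node* next = curr->next; /* save the link before freeing */
            free(curr);
            curr = next;
        }
        root[i] = NULL; /* the bucket is now empty */
    }
}
```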

<h2 id="a-better-hash-table-and-hashing-function">A Better Hash Table and Hashing Function</h2>

<p>As mentioned earlier, our first hashing function does a poor job of reducing collisions: every name that starts with the same letter lands in the same linked list. We want to avoid collisions as much as possible to keep the average search time close to O(1).</p>

<p>My lecturer provided reference code as an example of a good hashing function and hash table, along with explanations of why certain decisions were made.</p>

<p><strong>Credit belongs to the author of this code</strong>.</p>

<p>Two aspects of a hash table can be improved: its size and the hashing function it uses.</p>

<h3 id="size-of-hash-table">Size of Hash Table</h3>

<p>First, the size of the hash table affects how keys are distributed across it.</p>

<p>A good size is a prime number, since it has no factors other than 1 and itself. A non-prime size, on the other hand, can cause keys to be distributed non-uniformly.</p>

<p>Simply put, if many keys share a common factor with the size of the hash table, taking them modulo that size can only place them in slots that are multiples of that factor, so the remaining slots have a high probability of staying empty.</p>

<p>For example, suppose we chose 12 as the size of the hash table.
If many keys are multiples of 3, a factor of 12, they will only ever land in slots 0, 3, 6, and 9, while the other slots stay empty, increasing the chance of collision.</p>
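<p>This effect can be checked with a small sketch. The hypothetical helper <code class="language-plaintext highlighter-rouge">used_slots()</code> hashes the first <code class="language-plaintext highlighter-rouge">count</code> multiples of 3 with the division method (<code class="language-plaintext highlighter-rouge">key % size</code>) and counts how many distinct slots they occupy.</p>

```c
#include <stdbool.h>

/* Hypothetical helper: hash the first `count` multiples of 3 with
   key % size and return how many distinct slots they occupy. */
int used_slots(int size, int count) {
    bool used[128] = {false};
    if (size <= 0 || size > 128) {
        return -1; /* this sketch only supports sizes 1..128 */
    }
    int distinct = 0;
    for (int k = 0; k < count; ++k) {
        int slot = (3 * k) % size;
        if (!used[slot]) {
            used[slot] = true;
            ++distinct;
        }
    }
    return distinct;
}
```

<p>With a size of 12, the 24 multiples of 3 from 0 to 69 all land in slots 0, 3, 6, and 9, while with the prime size 13 they spread across all 13 slots.</p>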

<h3 id="hash-function">Hash Function</h3>

<p>Aside from being fast to compute, a good hashing function distributes keys as uniformly as possible.</p>

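<p>To do so, we use every character in the string rather than just the first one, converting each to a small number based on its position in the alphabet, so that different strings are more likely to end up with different keys.</p>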

<p>We also add a so-called zero-padding, in case empty strings are allowed, to prevent them from affecting universality.</p>

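<p>In addition, before adding each character’s value, we multiply the running total by a base number, which is strictly greater than the number of different values a single letter can take. This treats the string like a number written in that base, so strings made of the same letters in a different order no longer add up to the same value, which further reduces collisions.</p>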

<p>For example, since there are 26 possible lowercase letters, a base number like 31 is preferable. The base number 31 is also used by a method called <code class="language-plaintext highlighter-rouge">hashCode()</code> in Java’s <code class="language-plaintext highlighter-rouge">String</code> <a href="https://docs.oracle.com/javase/6/docs/api/java/lang/String.html#hashCode()">class</a>.</p>
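<p>To illustrate why the base matters, this sketch compares a plain sum of letter values with a base-31 polynomial, before any modulo is applied. Both helper names are hypothetical; they assume lowercase ASCII input and map <code class="language-plaintext highlighter-rouge">'a'</code> to 1 through <code class="language-plaintext highlighter-rouge">'z'</code> to 26, the same letter values used in the code below.</p>

```c
/* Hypothetical helpers for lowercase ASCII strings, mapping 'a'..'z' to 1..26.
   Unsigned arithmetic simply wraps around on very long strings. */
unsigned long sum_hash(const char* s) {
    unsigned long total = 0;
    for (; *s != '\0'; ++s) {
        total += (unsigned long) (*s - 'a' + 1); /* order is ignored */
    }
    return total;
}

unsigned long poly_hash(const char* s) {
    unsigned long total = 0;
    for (; *s != '\0'; ++s) {
        /* shift the previous "digits" one place in base 31, then add */
        total = total * 31 + (unsigned long) (*s - 'a' + 1);
    }
    return total;
}
```

<p>For example, <code class="language-plaintext highlighter-rouge">"abc"</code> and <code class="language-plaintext highlighter-rouge">"cba"</code> both sum to 6, but their polynomial values are 1026 and 2946, so the polynomial can tell these anagrams apart while the plain sum cannot.</p>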

<p>A more detailed explanation can be found in this Wikipedia page about <a href="https://en.wikipedia.org/wiki/Universal_hashing#Hashing_strings">Universal hashing</a>.</p>

<h3 id="code-implementation">Code Implementation</h3>

<p>With these changes, our hash table now has a size of 97, a prime number.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="n">hashTable</span><span class="p">[</span><span class="mi">97</span><span class="p">];</span>
</code></pre></div></div>

<p>The hash function looks like this:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="nf">hash</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span> <span class="o">*</span><span class="n">str</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">len</span> <span class="o">=</span> <span class="n">strlen</span><span class="p">(</span><span class="n">str</span><span class="p">);</span>

    <span class="kt">int</span> <span class="n">base</span> <span class="o">=</span> <span class="mi">31</span><span class="p">;</span>

    <span class="kt">int</span> <span class="n">MODPRIME</span> <span class="o">=</span> <span class="mi">97</span><span class="p">;</span>

    <span class="kt">int</span> <span class="n">ret</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>

    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">len</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
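        <span class="c1">// lowercase first so 'A' and 'a' both map to 1;</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="p">(</span><span class="n">ret</span> <span class="o">*</span> <span class="n">base</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">tolower</span><span class="p">((</span><span class="kt">unsigned</span> <span class="kt">char</span><span class="p">)</span> <span class="n">str</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">-</span> <span class="sc">'a'</span> <span class="o">+</span> <span class="mi">1</span><span class="p">);</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span> <span class="o">%</span> <span class="n">MODPRIME</span><span class="p">;</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">ret</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// % can be negative for non-letters; keep the key in [0, MODPRIME);</span>
            <span class="n">ret</span> <span class="o">+=</span> <span class="n">MODPRIME</span><span class="p">;</span>
        <span class="p">}</span>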
    <span class="p">}</span>

    <span class="k">return</span> <span class="p">(</span><span class="n">ret</span> <span class="o">*</span> <span class="n">base</span><span class="p">)</span> <span class="o">%</span> <span class="n">MODPRIME</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>
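<p>To see the difference in practice, here is a sketch that counts how many of the 97 buckets the seven fruit names from <code class="language-plaintext highlighter-rouge">main</code> occupy under each hash. The polynomial version lowercases each character and keeps the running value non-negative, so that capitalized names such as <code class="language-plaintext highlighter-rouge">Apple</code> cannot produce a negative index; <code class="language-plaintext highlighter-rouge">first_letter_hash()</code>, <code class="language-plaintext highlighter-rouge">poly_hash97()</code>, and <code class="language-plaintext highlighter-rouge">count_buckets()</code> are hypothetical names for this sketch.</p>

```c
#include <ctype.h>
#include <stdbool.h>
#include <stddef.h>

#define TABLE_SIZE 97 /* prime table size from the reference code */
#define BASE 31

/* The first hash from this post: one bucket per starting letter. */
int first_letter_hash(const char* str) {
    return tolower((unsigned char) str[0]) - 'a';
}

/* Polynomial hash in the style of the reference code, made
   case-insensitive and kept non-negative (assumptions of this sketch). */
int poly_hash97(const char* str) {
    int ret = 0;
    for (size_t i = 0; str[i] != '\0'; i++) {
        int value = tolower((unsigned char) str[i]) - 'a' + 1; /* 'a' -> 1 */
        ret = (ret * BASE + value) % TABLE_SIZE;
        if (ret < 0) {
            ret += TABLE_SIZE; /* C's % can return a negative remainder */
        }
    }
    return (ret * BASE) % TABLE_SIZE;
}

/* Hypothetical helper: how many distinct buckets do `n` names occupy?
   Assumes hash_fn returns keys in [0, TABLE_SIZE). */
int count_buckets(int (*hash_fn)(const char*), const char* names[], int n) {
    bool used[TABLE_SIZE] = {false};
    int distinct = 0;
    for (int i = 0; i < n; ++i) {
        int key = hash_fn(names[i]);
        if (!used[key]) {
            used[key] = true;
            ++distinct;
        }
    }
    return distinct;
}
```

<p>For these seven names, the polynomial version gives seven different keys (for example, <code class="language-plaintext highlighter-rouge">Apple</code> maps to 69 and <code class="language-plaintext highlighter-rouge">Avocado</code> to 3), while the first-letter hash squeezes them into four buckets.</p>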

<h2 id="conclusion">Conclusion</h2>

<p>It is very interesting to see how small details of a hash table, and the math behind them, can greatly affect its performance.
With a hash table, searching can be much faster than with the previously discussed data structures.</p>

<p>To read more about hash tables, see the very useful <a href="http://cseweb.ucsd.edu/~kube/cls/100/Lectures/lec16/lec16-2.html#pgfId-982677">notes</a> from UC San Diego, which were also linked in the reference code.</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Data Structures" /><summary type="html"><![CDATA[According to Wikipedia, a hash table, sometimes called a hash map, is a data structure that implements an associative array abstract data type, a structure that can map keys to values.]]></summary></entry><entry><title type="html">Doubly Linked List in C</title><link href="https://wilsonwongso.dev/posts/2020/03/doubly-linked-list-c/" rel="alternate" type="text/html" title="Doubly Linked List in C" /><published>2020-03-05T00:00:00+11:00</published><updated>2020-03-05T00:00:00+11:00</updated><id>https://wilsonwongso.dev/posts/2020/03/doubly-linked-list-c</id><content type="html" xml:base="https://wilsonwongso.dev/posts/2020/03/doubly-linked-list-c/"><![CDATA[<p>After learning how to implement a Singly Linked List, we’re going to implement a Doubly Linked List, which is similar to a Singly Linked List but adds a <code class="language-plaintext highlighter-rouge">prev</code> <code class="language-plaintext highlighter-rouge">pointer</code> that points to the node before it.</p>

<p>We’ll implement a Doubly Linked List using the C language. The complete code for this post can be found <a href="https://github.com/w11wo/wilsonwongso.dev/blob/master/files/code/doublyLinkedList.c">here</a>.</p>

<p>The following code is <strong>based</strong> on a lecture by Rhio Sutoyo, S.Kom., M.Sc. in the <strong>Data Structures</strong> course.</p>

<h2 id="header-files">Header Files</h2>

<p>The only header files we’ll be using are the following:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="cp">#include</span> <span class="cpf">&lt;stdio.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;stdlib.h&gt;</span><span class="cp">
#include</span> <span class="cpf">&lt;string.h&gt;</span><span class="cp">
</span></code></pre></div></div>

<h2 id="node-struct">Node Struct</h2>

<p>A <code class="language-plaintext highlighter-rouge">node</code> is a single element of the list, which in this case holds a student’s <code class="language-plaintext highlighter-rouge">name</code> and <code class="language-plaintext highlighter-rouge">gpa</code>. It also has a <code class="language-plaintext highlighter-rouge">pointer</code> to the next <code class="language-plaintext highlighter-rouge">node</code> and one to the previous <code class="language-plaintext highlighter-rouge">node</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">typedef</span> <span class="k">struct</span> <span class="n">node</span> <span class="p">{</span>
    <span class="kt">char</span> <span class="n">name</span><span class="p">[</span><span class="mi">200</span><span class="p">];</span>
    <span class="kt">double</span> <span class="n">gpa</span><span class="p">;</span>
    <span class="k">struct</span> <span class="n">node</span><span class="o">*</span> <span class="n">next</span><span class="p">;</span>
    <span class="k">struct</span> <span class="n">node</span><span class="o">*</span> <span class="n">prev</span><span class="p">;</span>
<span class="p">}</span> <span class="n">node</span><span class="p">;</span>
</code></pre></div></div>

<p>Notice that we also use <code class="language-plaintext highlighter-rouge">typedef</code>, which lets us omit the <code class="language-plaintext highlighter-rouge">struct</code> keyword when declaring a <code class="language-plaintext highlighter-rouge">node</code>.</p>

<h2 id="function-prototypes">Function Prototypes</h2>

<p>Since we are writing in C and will define these functions below our <code class="language-plaintext highlighter-rouge">main</code> function, we first need to declare a prototype for each of them. Here is the list of functions we’ll be implementing:</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">node</span><span class="o">*</span> <span class="nf">create_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">,</span> <span class="kt">double</span> <span class="n">gpa</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">sorted_push</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">,</span> <span class="kt">double</span> <span class="n">gpa</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">delete_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">key</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">print_list</span><span class="p">(</span><span class="kt">void</span><span class="p">);</span>
<span class="kt">void</span> <span class="nf">print_reversed_list</span><span class="p">(</span><span class="kt">void</span><span class="p">);</span>
</code></pre></div></div>

<h2 id="global-head-and-tail">Global Head and Tail</h2>

<p>For this example, we will create two global variables, <code class="language-plaintext highlighter-rouge">head</code> and <code class="language-plaintext highlighter-rouge">tail</code>, which point to the first and last elements of the list respectively. Since they are global, both start out as <code class="language-plaintext highlighter-rouge">NULL</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">node</span> <span class="o">*</span><span class="n">head</span><span class="p">,</span> <span class="o">*</span><span class="n">tail</span><span class="p">;</span>
</code></pre></div></div>

<h2 id="creating-a-node">Creating a Node</h2>

<p>To create our student <code class="language-plaintext highlighter-rouge">node</code>, we will implement the following function.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">node</span><span class="o">*</span> <span class="nf">create_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">,</span> <span class="kt">double</span> <span class="n">gpa</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// allocate memory of size 'node';</span>
    <span class="n">node</span><span class="o">*</span> <span class="n">student</span> <span class="o">=</span> <span class="p">(</span><span class="n">node</span><span class="o">*</span><span class="p">)</span> <span class="n">malloc</span><span class="p">(</span><span class="k">sizeof</span><span class="p">(</span><span class="n">node</span><span class="p">));</span>
    <span class="c1">// create a new node based on the given arguments;</span>
    <span class="n">strcpy</span><span class="p">(</span><span class="n">student</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">name</span><span class="p">);</span>
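    <span class="n">student</span><span class="o">-&gt;</span><span class="n">gpa</span> <span class="o">=</span> <span class="n">gpa</span><span class="p">;</span>
    <span class="c1">// start with no links; sorted_push() connects the node to the list;</span>
    <span class="n">student</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
    <span class="n">student</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>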

    <span class="k">return</span> <span class="n">student</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="sorted-push">Sorted Push</h2>

<p>Instead of implementing push front, push back, or push middle, we’re going to create a function that automatically inserts a node in ascending order of <code class="language-plaintext highlighter-rouge">gpa</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">sorted_push</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">name</span><span class="p">,</span> <span class="kt">double</span> <span class="n">gpa</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// create a new node;</span>
    <span class="n">node</span><span class="o">*</span> <span class="n">student</span> <span class="o">=</span> <span class="n">create_node</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">gpa</span><span class="p">);</span>

    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if list is empty;</span>
        <span class="n">head</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
        <span class="n">tail</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
        <span class="n">head</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="n">tail</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="n">head</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="n">tail</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
        <span class="c1">// traverse to the first node whose gpa is greater than or equal to the one being pushed;</span>
        <span class="k">while</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span> <span class="o">&amp;&amp;</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">gpa</span> <span class="o">&lt;</span> <span class="n">student</span><span class="o">-&gt;</span><span class="n">gpa</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
        <span class="p">}</span>

        <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="n">head</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if the head already has a value greater than the new node's;</span>
            <span class="c1">// append old head to the new node;</span>
            <span class="n">student</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
            <span class="n">head</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
            <span class="c1">// set new node as new head;</span>
            <span class="n">head</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
            <span class="n">head</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if we've reached the node after tail, i.e. all values are less than the value being pushed;</span>
            <span class="c1">// append new node to tail;</span>
            <span class="n">tail</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
            <span class="n">student</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="n">tail</span><span class="p">;</span>
            <span class="c1">// set new node as new tail;</span>
            <span class="n">tail</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
            <span class="n">tail</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="p">}</span> <span class="k">else</span> <span class="p">{</span> <span class="c1">// if we have to push the new node in the middle;</span>
            <span class="c1">// connect the current's previous node to the new node;</span>
            <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
            <span class="n">student</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span><span class="p">;</span>
            <span class="c1">// connect curr as the next of the new node;</span>
            <span class="n">student</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">curr</span><span class="p">;</span>
            <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="n">student</span><span class="p">;</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
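<p>Because <code class="language-plaintext highlighter-rouge">sorted_push</code> has to update both <code class="language-plaintext highlighter-rouge">next</code> and <code class="language-plaintext highlighter-rouge">prev</code> in three different cases, it helps to have a quick way to confirm the list is still consistent. The sketch below walks from <code class="language-plaintext highlighter-rouge">head</code> and checks that every <code class="language-plaintext highlighter-rouge">next</code>/<code class="language-plaintext highlighter-rouge">prev</code> pair agrees, that GPAs never decrease, and that the walk ends at <code class="language-plaintext highlighter-rouge">tail</code>. It assumes a node layout like the one in this post (a fixed-size <code class="language-plaintext highlighter-rouge">name</code> array is an assumption here), and <code class="language-plaintext highlighter-rouge">list_is_valid</code> is a hypothetical helper, not part of the original program; <code class="language-plaintext highlighter-rouge">sorted_push</code> is a compact copy of the function above so the block compiles on its own.</p>

```c
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Assumed node layout, mirroring the one used in this post. */
typedef struct node {
    char name[32];
    double gpa;
    struct node* next;
    struct node* prev;
} node;

node* head = NULL;
node* tail = NULL;

node* create_node(const char* name, double gpa) {
    node* student = malloc(sizeof(node));
    if (student == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    strncpy(student->name, name, sizeof(student->name) - 1);
    student->name[sizeof(student->name) - 1] = '\0';
    student->gpa = gpa;
    student->next = NULL;
    student->prev = NULL;
    return student;
}

/* Same three cases as sorted_push above: new head, new tail, or middle. */
void sorted_push(const char* name, double gpa) {
    node* student = create_node(name, gpa);
    if (head == NULL) {
        head = tail = student;
        return;
    }
    node* curr = head;
    while (curr != NULL && curr->gpa < gpa) {
        curr = curr->next;
    }
    if (curr == head) {
        student->next = head;
        head->prev = student;
        head = student;
    } else if (curr == NULL) {
        tail->next = student;
        student->prev = tail;
        tail = student;
    } else {
        curr->prev->next = student;
        student->prev = curr->prev;
        student->next = curr;
        curr->prev = student;
    }
}

/* Hypothetical validator: checks links in both directions and ascending order. */
bool list_is_valid(void) {
    if (head == NULL || tail == NULL) {
        return head == tail; /* both NULL means a valid empty list */
    }
    if (head->prev != NULL || tail->next != NULL) {
        return false; /* ends of the list must not point outward */
    }
    for (node* curr = head; curr != tail; curr = curr->next) {
        if (curr->next == NULL) return false;          /* ran off the end before reaching tail */
        if (curr->next->prev != curr) return false;    /* broken back-link */
        if (curr->next->gpa < curr->gpa) return false; /* out of ascending order */
    }
    return true;
}
```

<p>Calling <code class="language-plaintext highlighter-rouge">list_is_valid()</code> after each <code class="language-plaintext highlighter-rouge">sorted_push</code> in a debug build catches a forgotten <code class="language-plaintext highlighter-rouge">prev</code> update as soon as it happens, rather than later when printing in reverse.</p>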

<h2 id="delete-a-node-based-on-name">Delete a Node Based on Name</h2>

<p>We can delete a particular <code class="language-plaintext highlighter-rouge">node</code> based on its <code class="language-plaintext highlighter-rouge">name</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">delete_node</span><span class="p">(</span><span class="k">const</span> <span class="kt">char</span><span class="o">*</span> <span class="n">key</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if list is empty;</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"List is empty.</span><span class="se">\n</span><span class="s">"</span><span class="p">);</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
        <span class="c1">// traverse to the node to be deleted;</span>
        <span class="k">while</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span> <span class="o">&amp;&amp;</span> <span class="n">strcmp</span><span class="p">(</span><span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
        <span class="p">}</span>

        <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if key is not in the list;</span>
            <span class="n">printf</span><span class="p">(</span><span class="s">"</span><span class="se">\"</span><span class="s">%s</span><span class="se">\"</span><span class="s"> is not in the list.</span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">key</span><span class="p">);</span>
        <span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="n">head</span> <span class="o">&amp;&amp;</span> <span class="n">curr</span> <span class="o">==</span> <span class="n">tail</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if key is the only node in the list;</span>
            <span class="c1">// delete node;</span>
            <span class="n">free</span><span class="p">(</span><span class="n">curr</span><span class="p">);</span>
            <span class="c1">// reset head and tail;</span>
            <span class="n">head</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
            <span class="n">tail</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
        <span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="n">head</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if key is head;</span>
            <span class="c1">// set old head's next as new head;</span>
            <span class="n">head</span> <span class="o">=</span> <span class="n">head</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
            <span class="n">head</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
            <span class="c1">// free old head since it is no longer used;</span>
            <span class="n">free</span><span class="p">(</span><span class="n">curr</span><span class="p">);</span>
        <span class="p">}</span> <span class="k">else</span> <span class="k">if</span> <span class="p">(</span><span class="n">curr</span> <span class="o">==</span> <span class="n">tail</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if key is tail;</span>
            <span class="c1">// set the node before old tail to be the new tail;</span>
            <span class="n">tail</span> <span class="o">=</span> <span class="n">tail</span><span class="o">-&gt;</span><span class="n">prev</span><span class="p">;</span>
            <span class="n">tail</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="nb">NULL</span><span class="p">;</span>
            <span class="c1">// free old tail;</span>
            <span class="n">free</span><span class="p">(</span><span class="n">curr</span><span class="p">);</span>
        <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
            <span class="c1">// skip the node being deleted;</span>
            <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span><span class="o">-&gt;</span><span class="n">next</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
            <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="o">-&gt;</span><span class="n">prev</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span><span class="p">;</span>
            <span class="c1">// free the deleted node;</span>
            <span class="n">free</span><span class="p">(</span><span class="n">curr</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
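<p>The four deletion cases above (only node, head, tail, middle) all come down to the same idea: whatever points at the deleted node from the left (its <code class="language-plaintext highlighter-rouge">prev</code> node or <code class="language-plaintext highlighter-rouge">head</code>) and from the right (its <code class="language-plaintext highlighter-rouge">next</code> node or <code class="language-plaintext highlighter-rouge">tail</code>) must be redirected past it. The sketch below is an alternative way to write the same unlinking with two neighbor checks instead of four branches. It assumes the same node layout as this post; <code class="language-plaintext highlighter-rouge">delete_node_compact</code> and <code class="language-plaintext highlighter-rouge">push_back</code> are hypothetical names used only for illustration, and unlike the original it reports a missing key instead of printing “List is empty.” when the list is empty.</p>

```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Assumed node layout, mirroring the one used in this post. */
typedef struct node {
    char name[32];
    double gpa;
    struct node* next;
    struct node* prev;
} node;

node* head = NULL;
node* tail = NULL;

/* Hypothetical helper used only to build a list for this sketch. */
void push_back(const char* name, double gpa) {
    node* student = malloc(sizeof(node));
    if (student == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    snprintf(student->name, sizeof(student->name), "%s", name);
    student->gpa = gpa;
    student->next = NULL;
    student->prev = tail;
    if (tail != NULL) {
        tail->next = student;
    } else {
        head = student; /* list was empty */
    }
    tail = student;
}

/* Unlinks like delete_node above; the four cases fold into two neighbor checks. */
void delete_node_compact(const char* key) {
    node* curr = head;
    while (curr != NULL && strcmp(curr->name, key) != 0) {
        curr = curr->next;
    }
    if (curr == NULL) {
        printf("\"%s\" is not in the list.\n", key);
        return;
    }
    /* Left side: either the previous node or head points at curr. */
    if (curr->prev != NULL) {
        curr->prev->next = curr->next;
    } else {
        head = curr->next;
    }
    /* Right side: either the next node or tail points at curr. */
    if (curr->next != NULL) {
        curr->next->prev = curr->prev;
    } else {
        tail = curr->prev;
    }
    free(curr);
}
```

<p>When the deleted node is the only one, both checks take their <code class="language-plaintext highlighter-rouge">else</code> branch and reset <code class="language-plaintext highlighter-rouge">head</code> and <code class="language-plaintext highlighter-rouge">tail</code> to <code class="language-plaintext highlighter-rouge">NULL</code>, which is exactly what the explicit “only node” case does. The four-branch version in the post is arguably easier to read case by case; the compact one has fewer places to forget a pointer update.</p>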

<h2 id="print-linked-list">Print Linked List</h2>

<p>For convenience, we’ll also create a function to <code class="language-plaintext highlighter-rouge">print</code> the entire list.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">print_list</span><span class="p">(</span><span class="kt">void</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if list is empty;</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"List is empty.</span><span class="se">\n</span><span class="s">"</span><span class="p">);</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">head</span><span class="p">;</span>
        <span class="c1">// traverse through each node;</span>
        <span class="k">while</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">printf</span><span class="p">(</span><span class="s">"Name: %-10s GPA: %.2lf</span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">gpa</span><span class="p">);</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">next</span><span class="p">;</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="print-linked-list-in-reverse-order">Print Linked List in Reverse Order</h2>

<p>Since each node has a <code class="language-plaintext highlighter-rouge">prev</code> pointer, we can easily <code class="language-plaintext highlighter-rouge">print</code> the list in reverse order by starting from <code class="language-plaintext highlighter-rouge">tail</code>.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">void</span> <span class="nf">print_reversed_list</span><span class="p">(</span><span class="kt">void</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">head</span> <span class="o">==</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span> <span class="c1">// if list is empty;</span>
        <span class="n">printf</span><span class="p">(</span><span class="s">"List is empty.</span><span class="se">\n</span><span class="s">"</span><span class="p">);</span>
    <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">node</span><span class="o">*</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">tail</span><span class="p">;</span>
        <span class="c1">// traverse through each node backwards;</span>
        <span class="k">while</span> <span class="p">(</span><span class="n">curr</span> <span class="o">!=</span> <span class="nb">NULL</span><span class="p">)</span> <span class="p">{</span>
            <span class="n">printf</span><span class="p">(</span><span class="s">"Name: %-10s GPA: %.2lf</span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">name</span><span class="p">,</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">gpa</span><span class="p">);</span>
            <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="o">-&gt;</span><span class="n">prev</span><span class="p">;</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<h2 id="main-function">Main Function</h2>

<p>Lastly, here’s what the <code class="language-plaintext highlighter-rouge">main</code> function looks like.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="nf">main</span><span class="p">(</span><span class="kt">void</span><span class="p">)</span> <span class="p">{</span>

    <span class="n">sorted_push</span><span class="p">(</span><span class="s">"Steven"</span><span class="p">,</span> <span class="mi">3</span><span class="p">.</span><span class="mi">5</span><span class="p">);</span> <span class="c1">// [Steven]</span>
    <span class="n">sorted_push</span><span class="p">(</span><span class="s">"Bill"</span><span class="p">,</span> <span class="mi">2</span><span class="p">.</span><span class="mi">0</span><span class="p">);</span> <span class="c1">// [Bill, Steven]</span>
    <span class="n">sorted_push</span><span class="p">(</span><span class="s">"John"</span><span class="p">,</span> <span class="mi">3</span><span class="p">.</span><span class="mi">7</span><span class="p">);</span> <span class="c1">// [Bill, Steven, John]</span>
    <span class="n">sorted_push</span><span class="p">(</span><span class="s">"Ace"</span><span class="p">,</span> <span class="mi">2</span><span class="p">.</span><span class="mi">5</span><span class="p">);</span> <span class="c1">// [Bill, Ace, Steven, John]</span>

    <span class="n">delete_node</span><span class="p">(</span><span class="s">"Ace"</span><span class="p">);</span> <span class="c1">// [Bill, Steven, John]</span>

    <span class="n">print_list</span><span class="p">();</span>

    <span class="k">return</span> <span class="mi">0</span><span class="p">;</span>

<span class="p">}</span>
</code></pre></div></div>
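<p>The <code class="language-plaintext highlighter-rouge">main</code> function above returns without releasing the nodes allocated for the remaining students. The operating system reclaims that memory when the process exits, but freeing the list explicitly keeps leak checkers such as Valgrind quiet and matters if the list is reused. A minimal sketch, assuming the same node layout and global <code class="language-plaintext highlighter-rouge">head</code>/<code class="language-plaintext highlighter-rouge">tail</code>; <code class="language-plaintext highlighter-rouge">free_list</code> and <code class="language-plaintext highlighter-rouge">push_back</code> are hypothetical names, not part of the original program.</p>

```c
#include <stdio.h>
#include <stdlib.h>

/* Assumed node layout, mirroring the one used in this post. */
typedef struct node {
    char name[32];
    double gpa;
    struct node* next;
    struct node* prev;
} node;

node* head = NULL;
node* tail = NULL;

/* Hypothetical helper used only to build a list for this sketch. */
void push_back(const char* name, double gpa) {
    node* student = malloc(sizeof(node));
    if (student == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    snprintf(student->name, sizeof(student->name), "%s", name);
    student->gpa = gpa;
    student->next = NULL;
    student->prev = tail;
    if (tail != NULL) {
        tail->next = student;
    } else {
        head = student; /* list was empty */
    }
    tail = student;
}

/* Hypothetical cleanup: frees every node from head to tail and resets the list. */
void free_list(void) {
    node* curr = head;
    while (curr != NULL) {
        node* next = curr->next; /* save the link before curr is freed */
        free(curr);
        curr = next;
    }
    head = NULL;
    tail = NULL;
}
```

<p>In the post’s <code class="language-plaintext highlighter-rouge">main</code>, calling <code class="language-plaintext highlighter-rouge">free_list()</code> just before <code class="language-plaintext highlighter-rouge">return 0;</code> would release the remaining Bill, Steven, and John nodes. Resetting <code class="language-plaintext highlighter-rouge">head</code> and <code class="language-plaintext highlighter-rouge">tail</code> afterwards means the other functions see an empty list instead of dangling pointers.</p>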

<h2 id="conclusion">Conclusion</h2>

<p>With a doubly linked list, we can move both forward and backward from any node. This makes operations such as inserting before a node, deleting a node, and printing the list in reverse order straightforward, whereas a singly linked list would have to traverse again from the head to reach a node’s predecessor.</p>]]></content><author><name>Wilson Wongso</name><email>wilsonwong961@gmail.com</email></author><category term="Data Structures" /><summary type="html"><![CDATA[After learning how to implement Singly Linked List, we’re going to implement Doubly Linked List, which is similar to Singly Linked List, but with the addition of a prev pointer which points to the node before it.]]></summary></entry></feed>