import React from "react";

function ProjGrok() {
    return (
        <>
            <h2>Safety and Interpretability: Progress measures for grokking tickets</h2>
            [<a target="_blank" href="https://github.com/nicholaschenai/grokking-tickets-lambda">Code</a>]
            [<a target="_blank" href="https://github.com/nicholaschenai/interp-notes">Lit Review</a>]
            [<a target="_blank" href="https://github.com/nicholaschenai/grokking-tickets-lambda/blob/devinterp/report.pdf">Full Report</a>]
            
            <h3>Intro</h3>
            <ul>
                <li>Motivation: AI models display emergent properties when scaled. While some of these properties are useful, like chain-of-thought reasoning or in-context learning, others are undesirable (e.g. <a target="_blank" href="https://arxiv.org/abs/2201.03544">emergent reward hacking</a>). Studying such emergence is important, especially when it takes the form of sharp phase transitions. Phase transitions are also worth studying because they are inherent to composition, a core component of reasoning.</li>
                <li> <a target="_blank" href="https://arxiv.org/abs/2301.05217">Nanda <em>et al.</em> (2023)</a> investigated grokking, in which the test loss initially rises or plateaus while the train loss decreases, then drops suddenly long after the train loss has stabilized (see the figure below). They studied a transformer trained on the modular addition task, reverse-engineered the network from its weights (mechanistic interpretability), and found progress measures that track the underlying phenomena.</li>
                <p>
                    <img className="img-fluid w-100 mt-3" src={require("../img/nanda_progress_measures.png")}/>
                    Image from Nanda <em>et al.</em> (2023): an example of grokking. The excluded and restricted losses are the progress measures they identified to track the underlying phenomena.
                </p>
                <li>Excluded loss: the loss after removing from the logits the key components required for generalization (in Nanda <em>et al.</em>, the key Fourier frequencies). Since a generalizing network relies on these components, we expect the excluded loss to be high once the network generalizes.</li>
                <li>Restricted loss: the loss when keeping only those key components in the logits, so we expect the restricted loss to drop sharply when the network generalizes.</li>

                <li>They showed that the transformer goes through these phases during grokking: 
                    <ul>
                        <li>Memorization: Train loss decreases sharply at the beginning. Excluded loss also decreases since the network is memorizing and not generalizing.</li>
                        <li>Circuit formation: Although the train and test losses plateau, the excluded loss increases, showing that the network is starting to form the generalizing circuit.</li>
                        <li>Cleanup: The test loss, restricted loss and sum of squared weights drop sharply, showing that the model generalizes while weight decay prunes the memorized solution away.</li>
                    </ul>
                </li>
                <li> <strong>The big question:</strong> How is it possible that neural networks, which are fundamentally <em>continuous</em>, are able to learn <em>discrete</em> algorithms? These discrete algorithms contain multiple components which are useless individually, so where did gradient descent get the signal to uncover these structures?</li> 
                <li>They hypothesized that this might be explained by the <strong>lottery ticket hypothesis</strong> (a dense, randomly initialized network contains a subnetwork that, trained in isolation from the same initialization, performs as well as the full network): early on, the network is a superposition of <em>circuits</em>, and gradient descent slowly boosts the useful ones until at least one of them is sufficiently developed. This makes the other circuits relevant, so gradient descent then boosts all of them sharply.</li>
                <li>Recently, <a target="_blank" href="https://arxiv.org/abs/2310.19470">Minegishi <em>et al.</em> (2023)</a> showed that lottery tickets (subnetworks whose connectivity is derived from the trained network but whose weights are reset to their values at initialization) derived from the network after generalization (hence termed 'grokking tickets') generalize faster than the base model when trained, while lottery tickets derived at earlier stages of training tend to generalize slower than the base model or fail to generalize at all! This suggests that the structure of the network matters, and supports the hypothesis above!
                </li>

            </ul>
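            <p>The excluded and restricted losses above can be sketched as projections of the logits away from, or onto, the key-frequency directions. This is a simplified NumPy illustration under stated assumptions, not the implementation of Nanda <em>et al.</em>: the helper names are hypothetical, the key frequencies are passed in by hand (in the paper they are identified from the trained network's Fourier spectrum), and only the input-pair directions cos(w_k (a+b)) and sin(w_k (a+b)), with w_k = 2πk/p, are projected, so their exact procedure differs in its details.</p>
```python
import numpy as np

P = 113  # modulus used by Nanda et al. (2023)


def make_dataset(p=P):
    # All input pairs (a, b) with label c = (a + b) mod p, flattened row-major.
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    a, b = a.ravel(), b.ravel()
    return a, b, (a + b) % p


def key_frequency_basis(a, b, key_freqs, p=P):
    # Columns cos(w_k (a + b)) and sin(w_k (a + b)) over the input pairs,
    # with w_k = 2 pi k / p, orthonormalized by a reduced QR decomposition.
    # key_freqs is a hypothetical, hand-picked list of frequencies.
    cols = []
    for k in key_freqs:
        w_k = 2 * np.pi * k / p
        cols.append(np.cos(w_k * (a + b)))
        cols.append(np.sin(w_k * (a + b)))
    basis, _ = np.linalg.qr(np.stack(cols, axis=1))
    return basis  # shape (p * p, 2 * len(key_freqs))


def split_logits(logits, basis):
    # logits has one row per input pair and one column per output class c.
    restricted = basis @ (basis.T @ logits)  # keep only key-frequency components
    excluded = logits - restricted           # remove them
    return restricted, excluded


def cross_entropy(logits, labels):
    # Numerically stable mean cross-entropy.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```
            <p>With a model's logits over all p² input pairs, the restricted loss would be cross_entropy(restricted, labels) and the excluded loss cross_entropy(excluded, labels), with the excluded loss evaluated on the training pairs as in Nanda <em>et al.</em> Projecting along the input-pair dimension keeps the two pieces complementary: their sum recovers the original logits.</p>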

            <h3>Contributions</h3>
            <ul>
                <li>We reproduce the core result of Minegishi <em>et al.</em> (2023), confirming that the later the stage of grokking from which a lottery ticket is derived, the faster it generalizes.</li>
                <li>We also show that the final test loss improves with the stage from which the ticket is derived, and that only the grokking ticket reaches a lower test loss than the base model of Nanda <em>et al.</em></li>
                <li>We show that lottery tickets derived at different stages of grokking exhibit different behaviors (such as immediate generalization!), which can be discerned through the progress measures we track! This further supports the lottery ticket hypothesis as an explanation of phase transitions.</li>
                <li>We also calculate the <a target="_blank" href="https://arxiv.org/abs/2308.12108">RLCT</a>, an invariant of the neural network, and show that it can also discern the different behaviors of lottery tickets, and even leads the occurrence of grokking. This matters because the progress measures (except the sum of squared weights) are specific to the modular addition task and none of them predicts grokking ahead of time, whereas the RLCT is generic.</li>
            </ul>
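            <p>For readers unfamiliar with how an RLCT estimate is obtained, here is a minimal NumPy sketch of the SGLD-based estimator described in the linked paper (which calls it the local learning coefficient): λ̂ = nβ(E[L_n(w)] − L_n(w*)) with β = 1/log n, where the expectation is over SGLD samples from a posterior localized around the trained weights w*. It is a toy illustration under assumptions, not our pipeline. It uses full-batch gradients instead of minibatches and a two-parameter linear-Gaussian model instead of the transformer. The names and hyperparameter values (estimate_llc, step_size, gamma, burn_in) are hypothetical.</p>
```python
import numpy as np


def estimate_llc(X, y, w_star, num_steps=10000, step_size=2e-4,
                 gamma=1.0, burn_in=2000, seed=0):
    """SGLD estimate of the local learning coefficient (RLCT) at w_star for a
    toy linear-Gaussian model whose per-sample loss is 0.5 * (y - x . w) ** 2."""
    rng = np.random.default_rng(seed)
    n = len(y)
    beta = 1.0 / np.log(n)  # inverse temperature beta = 1 / log n

    def loss(w):
        return 0.5 * np.mean((y - X @ w) ** 2)

    def grad(w):
        return -X.T @ (y - X @ w) / n

    w = w_star.copy()
    losses = []
    for _ in range(num_steps):
        # Langevin step targeting the localized, tempered posterior
        # exp(-n * beta * L_n(w) - gamma / 2 * |w - w_star| ** 2).
        drift = -n * beta * grad(w) - gamma * (w - w_star)
        noise = rng.normal(size=w.shape)
        w = w + 0.5 * step_size * drift + np.sqrt(step_size) * noise
        losses.append(loss(w))
    # n * beta * (posterior mean loss - loss at w_star), discarding burn-in.
    return n * beta * (np.mean(losses[burn_in:]) - loss(w_star))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.normal(size=(1000, 2))
    y = X @ np.array([1.0, -2.0]) + rng.normal(size=1000)
    w_star = np.linalg.lstsq(X, y, rcond=None)[0]  # exact least-squares minimizer
    print(estimate_llc(X, y, w_star))
```
            <p>For a regular (non-singular) model with d parameters the learning coefficient equals d/2, so this two-parameter toy should give a value close to 1. In singular models such as transformers the RLCT can fall below d/2, which is what makes it informative about the structure the network has learned.</p>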

            <h3>Approach</h3>
            <ul>
                <li>Obtain the network weights from Nanda <em>et al.</em> (2023) at various stages: initialization, end of memorization, end of circuit formation, and end of cleanup. Only the ticket derived from the last one qualifies as a 'grokking ticket'.</li>
                <li>For each of the networks above (excluding initialization), derive a mask which prunes away the bottom 40% of weights by magnitude.</li>
                <li>Apply each derived mask to the network at <strong>initialization</strong> and train it with the setup of Nanda <em>et al.</em> (2023), tracking the progress measures.</li>
                <li>Tune the SGLD hyperparameters, then estimate the RLCT of each network over the training epochs.</li>
            </ul>
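            <p>The masking steps above can be sketched in NumPy as follows. This is a hedged illustration rather than the code in our repository. The helper names are hypothetical and weights are represented as plain lists of arrays. It also assumes a single global magnitude threshold across all weight matrices, whereas the 40% could instead be applied layer by layer.</p>
```python
import numpy as np

PRUNE_FRACTION = 0.4  # remove the bottom 40% of weights by magnitude


def magnitude_masks(trained_weights, prune_fraction=PRUNE_FRACTION):
    # Assumption: one global threshold over all tensors rather than per layer.
    magnitudes = np.concatenate([np.abs(w).ravel() for w in trained_weights])
    threshold = np.quantile(magnitudes, prune_fraction)
    # True = keep the connection, False = prune it.
    return [np.greater_equal(np.abs(w), threshold) for w in trained_weights]


def apply_masks(init_weights, masks):
    # Lottery ticket: connectivity from the trained network,
    # weight values reset to their values at initialization.
    return [w0 * m for w0, m in zip(init_weights, masks)]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    init = [rng.normal(size=(128, 128)), rng.normal(size=128)]
    # Stand-in for the trained weights at some stage of grokking.
    trained = [w + rng.normal(size=w.shape) for w in init]
    masks = magnitude_masks(trained)
    ticket = apply_masks(init, masks)
    kept = sum(m.sum() for m in masks) / sum(m.size for m in masks)
    print("fraction of weights kept:", kept)
```
            <p>When retraining a ticket, pruned entries are typically kept at zero by re-applying the mask after every optimizer step (or by zeroing their gradients), so that the subnetwork's connectivity stays fixed throughout training.</p>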
            <h3>Grokking Tickets Speed Up Generalization</h3>
            <p>
                <img className="img-fluid w-100 mt-3" src={require("../img/loss_comparison_test.png")}/>
                Test losses of the various lottery tickets compared with the base model of Nanda <em>et al.</em>
            </p>
            
            <ul>
                <li>All lottery tickets generalize faster than the base model, and the later the stage they are derived from, the faster they generalize.</li>
                <li>The final test loss improves with the stage of grokking the lottery ticket was derived from, and only the grokking ticket (end of cleanup) achieves a lower test loss than the base model.</li>
                <li>We observe the standard grokking curve for the memorization lottery ticket, but immediate generalization for the circuit formation and cleanup lottery tickets.</li>
                <li>These observations suggest that generalization is a search over network structures: initializing a network with good structure places it ahead in the optimization landscape, but fixing it to a suboptimal structure (i.e. the non-grokking tickets) hurts the final performance, as it cannot transition to the optimal one.</li>
            </ul>
            <h3>Progress Measures and RLCT Explain Phases on Networks Trained From The Memorization Lottery Ticket</h3>
            <img className="img-fluid w-100 mt-3" src={require("../img/excluded_loss_mean_memorization.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/restricted_loss_memorization.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/sum_sq_weight_total_memorization.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/rlct_memorization.png")}/>
            <div>
                <ul>
                    <li>The standard behavior of the base model (discussed in the intro) is reproduced, consistent with Nanda <em>et al.</em> (2023): the stages of grokking are visible and tracked by both the progress measures and the RLCT.</li>
                    <li>The RLCT mean loss exhibits three phases that explain grokking: a sharp decline over the memorization phase, a slow decline over the circuit formation phase, then a sharp decline that <strong>leads</strong> the cleanup phase.</li>
                </ul>
            </div>
            
            <h3>Progress Measures and RLCT Reflect Immediate Generalization On Networks Trained From The Circuit Formation / Cleanup Lottery Ticket</h3>
            <img className="img-fluid w-100 mt-3" src={require("../img/excluded_loss_mean_cleanup.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/restricted_loss_cleanup.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/sum_sq_weight_total_cleanup.png")}/>
            <img className="img-fluid w-100 mt-3" src={require("../img/rlct_cleanup.png")}/>
            <div>
                <ul>
                    <li>Graphs for the circuit formation lottery ticket are omitted, as they are similar to those of the cleanup lottery ticket.</li>
                    <li>The networks generalize immediately, showing the advantage of initializing with good structure!</li>
                    <li>The circuit formation phase is characterized by an increase in excluded loss. Here, the increase lasts only briefly, suggesting that the network already forms circuits early on.</li>
                    <li>The restricted loss decreases immediately and sharply to below the train loss, indicating that the cleanup phase starts right away during training.</li>
                    <li>The total sum of squared weights shows only one sharp decline instead of two, again reinforcing that training is accelerated.</li>
                    <li>The RLCT now declines smoothly at a decreasing rate (instead of showing three phases as before), plateauing around the time the test loss stops dropping sharply. This shows that it distinguishes the immediate generalization behavior from the grokking behavior in the previous section.</li>
                    <li>Giving the network an advantageous structure and seeing immediate generalization and accelerated grokking phases (as seen from the progress measures) strongly supports Nanda <em>et al.</em>'s hypothesis that once one component is formed, the other components rapidly form as they become important in the context of the first!</li>
                </ul>
            </div>
            
        </>
    )
}
export default ProjGrok;