Transformers-based Generative AI opens enormous collaboration opportunities for business — but also poses risks of exposing IP. Here's how we are keeping all parties secure for the future of AI.
Generative AI applications like OpenAI’s ChatGPT, GitHub’s Copilot, or Stability’s Stable Diffusion are getting a lot of well-deserved attention these days. But despite the hype, many major enterprises and government organizations may be holding back from adopting them – in large part due to real security concerns.
Built using a type of machine learning architecture called the transformer, generative AI models are trained by ingesting massive amounts of data from within an organization or from the open internet. When it comes time to use the model, users input “prompts” to receive generated content back from the model. Often, both the training data and the prompt data are sensitive and/or proprietary — as evidenced by generative AI’s massive and growing intellectual property problem. Meanwhile, even willing AI participants can’t always be assured of full data security. For just one example, some ChatGPT users’ conversations were partially exposed to the public earlier this year, which is one of the reasons the Italian government temporarily banned the service.
Businesses, for their part, face a real generative AI dilemma: embrace AI and take real security risks, or stay on the sidelines and lose out on innovation and an AI competitive edge. Given the risks, many companies have decided to limit their generative AI use. Samsung implemented an internal ban on staff using ChatGPT after an employee accidentally uploaded sensitive code to the system. The concern has spread to companies including Apple, Bank of America, and JPMorgan, which have all placed restrictions of their own on ChatGPT use.
And to be clear, the risks and the opportunities are two-way streets. Both enterprises and generative AI providers stand to gain much – and risk much too – from greater generative AI use. These risks exist both at the level of model training (“teaching” the model to understand the domain and knowledge using an initial dataset) and model inference (using the model to generate content for the model's user).
For instance, say a workplace communications company like Slack or Zoom wants to use its customers’ conversation records — sensitive information its customers have entrusted it with — to fine-tune a large language model (LLM). If the organization has vast resources – a team of AI experts on staff and literally hundreds of millions of dollars to commit to the project – then it can build (train) its own AI on its own data from scratch. But for most companies that’s obviously not an option, so they will need to fine-tune another company’s foundation model on their own data. They will also need to host the new AI model somewhere for model inference (that is, applying the model to the use case).
Both training and inference require either (a) moving the organization’s customers’ conversation data to the model provider’s cloud instance, or (b) moving the model provider’s proprietary model (weights and parameters) to the organization’s cloud. Neither option is ideal, as both expose proprietary and sensitive data to another organization – but unfortunately, this is the current status quo. The obvious question is: is there a way to train and consume generative AI without exposing highly sensitive information?
The answer is that yes, there is a way to combine the assets in a fully secure way – through secure multi-party computation (SMPC), an encryption technique designed for running computations while preserving the privacy of all parties involved.
Our latest product, Pyte’s SecureAI, runs transformer-based models on top of our proprietary SMPC protocol, keeping security front and center. This enables organizations to train and run inference on their generative AI models, either on their own cloud or on their model provider’s cloud, without having to sacrifice the privacy of proprietary datasets. To understand the technical details behind how we did it, keep reading — otherwise, jump to the end to see our conclusions.
Secure multi-party computation (SMPC) is one of the most popular privacy-enhancing technologies (PETs). In a nutshell, it allows several parties with sensitive datasets to collaborate on a computational task without leaking any information about their respective datasets to each other. For a high-level overview, check out the Wikipedia article or a survey GitHub repository. We find the theory behind SMPC to be quite elegant, since it allows any computation to be carried out in a privacy-preserving way, which cannot be said of some other PETs.

Before diving in, let’s go through a quick overview of how SMPC technology works. SMPC is effective when multiple organizations want to collaborate on datasets together. This can be cross-company collaboration, collaboration within a company (across departments or geographies), and many other cases. The datasets are transformed in each organization’s local environment into a format called secret shares.
Secret share refers to cryptographic techniques that split a secret into multiple shares and distribute them amongst parties in a way that the secret can only be reconstructed when a sufficient number of parties collaborate.
There are a number of different SMPC protocols for secret shares available in industry and academia, but the one Pyte uses is called ABY3 [1]. ABY3 can be found in Pyte’s CipherCore (open source) and SecureAI products.
Secret shares in ABY3 work as follows. Let’s say we have one bit [x], represented as a sum of three parts: [x] = (x_1 + x_2 + x_3) % 2, where [x_1], [x_2] and [x_3] are chosen randomly. Each party knows only two out of three variables [x_1], [x_2], [x_3]; thus, no party can reconstruct the exact value of [x].
If we have two bits [x] and [y] stored in this form, we can compute the representation of [z] = ([x] + [y]) % 2 very quickly: [z_i] = ([x_i] + [y_i]) % 2. So every party can compute new values without sending information to other parties.
Computing [x] * [y] is more complicated and requires sending messages over the network as well as evaluating certain cryptographic primitives.
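To make the sharing scheme concrete, here is a minimal Python sketch (a toy illustration, not Pyte’s implementation) of splitting a bit into three shares, giving each party two of them, and adding two shared bits without any communication:

```python
import secrets

def share_bit(x: int) -> list[tuple[int, int]]:
    """Split a bit x into shares with x = (x1 + x2 + x3) % 2.
    Party i receives the pair (x_i, x_{i+1}), so each party sees only two shares."""
    x1 = secrets.randbelow(2)
    x2 = secrets.randbelow(2)
    x3 = (x - x1 - x2) % 2
    shares = [x1, x2, x3]
    return [(shares[i], shares[(i + 1) % 3]) for i in range(3)]

def reconstruct(p0: tuple[int, int], p1: tuple[int, int]) -> int:
    """Any two parties together hold all three shares and can recover the bit."""
    x1, x2 = p0       # party 0 holds (x1, x2)
    _, x3 = p1        # party 1 holds (x2, x3)
    return (x1 + x2 + x3) % 2

def add_local(px: tuple[int, int], py: tuple[int, int]) -> tuple[int, int]:
    """Addition mod 2 is purely local: each party just adds its own shares.
    Multiplication, in contrast, would require communication between parties."""
    return ((px[0] + py[0]) % 2, (px[1] + py[1]) % 2)

x_parties = share_bit(1)
y_parties = share_bit(1)
z_parties = [add_local(px, py) for px, py in zip(x_parties, y_parties)]
assert reconstruct(z_parties[0], z_parties[1]) == (1 + 1) % 2  # == 0
```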
We can use addition and multiplication operations as simple building blocks for arbitrarily complex computations - these complex computations can be represented as graphs. Below is an example of a simple computation (compute a product of two 5x5 matrices), and the graph for the corresponding secure protocol.
It is very important to optimize both the size and the depth of the resulting graph of a secure protocol. The size corresponds to total compute, and the depth corresponds to the number of communication rounds, and hence to network latency.
Now that we have a broad understanding of SMPC and how its computations work, let’s discuss how we use SMPC protocols to perform machine learning training and inference, specifically for the transformer architecture.
In order to perform training or inference of a transformer in SMPC, we first need to represent the training/inference algorithms as graphs suitable for SMPC compilation. For this we take the foundation model in the ONNX format and convert it to the CipherCore graph format. For inference, we encode the forward pass; for training/fine-tuning, we encode both the forward and backward passes together with the required optimizer, which for transformers is typically Adam.
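As a rough illustration of the starting point (the file name is a placeholder, and this uses the open-source onnx Python package rather than our internal tooling), one can walk the ONNX graph node by node to see which operations will need SMPC-friendly counterparts; the actual conversion to the CipherCore graph format is considerably more involved:

```python
import onnx

# Load a foundation model exported to ONNX (path is a placeholder).
model = onnx.load("foundation_model.onnx")

# Walk the computation graph and count operator types; ops like Softmax,
# Sigmoid, or Gelu are the ones that need SMPC-friendly replacements.
op_counts: dict[str, int] = {}
for node in model.graph.node:
    op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1

for op, count in sorted(op_counts.items()):
    print(f"{op}: {count}")
```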
Although this sounds simple in principle, the devil is in the details. If we perform the above compilation process naively, we will end up with an SMPC protocol incurring a huge computational overhead. Thus, it’s important to do it as optimally as possible. Below, we describe a few of the optimizations that we do.
Executing operations in parallel
It’s not hard to see that computational overhead might be a serious issue for SMPC protocols. For example, to compare two n-bit numbers stored in secret-shared form, we can build a circuit that requires n round-trips between parties.
Usually, the latency between parties located in the same datacenter is about 1 ms, so comparing two 64-bit integers could require tens of milliseconds. If comparing two numbers takes so long, how can we run a more complex computation in SMPC?
One possible optimization is to do a lot of computations in parallel! Usually, ML models require a lot of operations in total, but often you don’t need to run them sequentially. For example, if you want to multiply two (n x n)-matrices, you can compute each value of the resulting matrix in parallel. In total, it requires O(n^3) local operations and O(n^2) values sent over the network, but all of them could be sent in parallel, which means only O(1) round trips.
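To make this concrete, here is a toy Python sketch (not the ABY3 protocol itself: it omits the re-randomization and re-sharing steps a real protocol needs for security) showing that, with three-party replicated shares, every entry of a matrix product can be computed from purely local cross-terms, so all values could be exchanged in a single round:

```python
import secrets
import numpy as np

def rand_matrix(n: int) -> np.ndarray:
    """Uniformly random n x n matrix over the ring Z_{2^64} (uint64 wraps mod 2^64)."""
    return np.array([[secrets.randbelow(2**64) for _ in range(n)] for _ in range(n)],
                    dtype=np.uint64)

def share_matrix(m: np.ndarray):
    """Split m into three additive shares with m = m1 + m2 + m3 (mod 2^64).
    Party i holds the pair (m_i, m_{i+1}), as in replicated secret sharing."""
    n = m.shape[0]
    m1, m2 = rand_matrix(n), rand_matrix(n)
    m3 = m - m1 - m2                       # uint64 arithmetic wraps, i.e. works mod 2^64
    shares = [m1, m2, m3]
    return [(shares[i], shares[(i + 1) % 3]) for i in range(3)]

def local_matmul_term(px, py):
    """The cross-terms a party can compute from the shares it holds.
    No communication is needed for this step; a real protocol follows it with
    a single re-sharing round for all entries at once."""
    xi, xj = px
    yi, yj = py
    return xi @ yi + xi @ yj + xj @ yi

n = 5
x = np.arange(n * n, dtype=np.uint64).reshape(n, n)
y = np.arange(n * n, dtype=np.uint64).reshape(n, n) + np.uint64(1)

x_parties = share_matrix(x)
y_parties = share_matrix(y)
partials = [local_matmul_term(px, py) for px, py in zip(x_parties, y_parties)]
assert np.array_equal(partials[0] + partials[1] + partials[2], x @ y)
```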
One major difficulty in running practical computations in SMPC is the necessity to deal with floating point numbers.
By default, the ABY3 protocol (as well as any other SMPC protocol) is used to do computations over bits or integers. But in machine learning, most calculations are done in floating point numbers. Usually, the floating-point number [f] is stored as two integers [s] and [exp] such that: [f] = [s] * 2^[exp]. Such a representation allows us to store both very big and very small numbers.
Unfortunately, if we want to add two numbers stored in this format in SMPC, it would require comparing exponents of the numbers, which is expensive. Instead, fixed-point arithmetic is usually used where [exp] is the same for all numbers. Adding two numbers in fixed-point arithmetic is the same as adding two integers. Multiplication of such numbers is the same as multiplication of integers and then dividing by 2^[exp], which is possible to do efficiently.
We used [exp]=-15 in our experiments, which gave precise enough values. Using smaller [exp] could provide better precision, but it also means the largest representable number will be smaller, which can lead to overflows. We used 64-bit numbers to store [s], but if higher precision is needed, it is possible to use 128-bit numbers (at the cost of 2x performance overhead).
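Here is a minimal plaintext sketch of this fixed-point encoding (illustrative only; the secret-shared version performs the same arithmetic on shares):

```python
FRAC_BITS = 15          # corresponds to exp = -15
SCALE = 1 << FRAC_BITS  # 2^15

def encode(f: float) -> int:
    """Represent a real number as an integer s with f ≈ s * 2^-15."""
    return round(f * SCALE)

def decode(s: int) -> float:
    return s / SCALE

def fxp_add(a: int, b: int) -> int:
    # Addition is plain integer addition.
    return a + b

def fxp_mul(a: int, b: int) -> int:
    # Multiplication is integer multiplication followed by division by 2^15
    # (a truncation step, which SMPC protocols can implement efficiently).
    return (a * b) >> FRAC_BITS

a, b = encode(1.5), encode(-0.25)
print(decode(fxp_add(a, b)))  # ~1.25
print(decode(fxp_mul(a, b)))  # ~-0.375
```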
Transformer models use the following Softmax operation inside: softmax(z)_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
The exponents here can grow very fast. To make this formula numerically stable, it is common to find the maximum of the inputs and subtract it from all of them (this doesn’t change the result, but it keeps the exponentials, and hence the denominator, from blowing up). Finding the maximum requires comparing numbers, which involves many network round-trips.
So instead of using Softmax in the model, we replaced it with a more SMPC-friendly function. Instead of calculating exp(z_i), we can use ReLU(z_i) in this formula. Experiments show that the quality of the results is the same after the replacement, but training takes roughly 1.5x longer.
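A small plaintext sketch of the replacement (the handling of an all-zero denominator below is our own assumption, not necessarily what the model uses):

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()              # stabilization step: needs an expensive max in SMPC
    e = np.exp(z)
    return e / e.sum()

def relu_softmax(z: np.ndarray) -> np.ndarray:
    """SMPC-friendly variant: replace exp with ReLU, avoiding both exp and max."""
    r = np.maximum(z, 0.0)
    denom = r.sum()
    if denom == 0.0:             # assumption: fall back to uniform when all inputs <= 0
        return np.full_like(z, 1.0 / len(z))
    return r / denom

z = np.array([2.0, 1.0, 0.1, -1.0])
print(softmax(z))
print(relu_softmax(z))
```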
Some other important functions in machine learning are not SMPC-friendly either. For example, the sigmoid function, sigmoid(x) = 1 / (1 + exp(-x)), is hard to compute efficiently in SMPC if we just calculate it by its definition.
Luckily, we can instead build a piecewise-linear approximation of such a function. In our experiments, we split all reasonable values of [x] into 32 buckets, and each bucket uses its own linear approximation of the function.
In the case of using 8 buckets, the difference is still significant:
But for 32 buckets, the difference becomes negligible:
Such piecewise-linear approximations work quickly and show results very close to exact computations. We also use the same technique for other functions, e.g., GeLU and its derivative.
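As an illustration of the technique (the input range and bucket boundaries below are illustrative assumptions, not Pyte’s exact choices), here is a plaintext sketch of a 32-bucket piecewise-linear sigmoid:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Split a "reasonable" input range into buckets and fit one line per bucket
# by interpolating the sigmoid between the bucket's endpoints.
LO, HI, BUCKETS = -8.0, 8.0, 32
edges = np.linspace(LO, HI, BUCKETS + 1)

def piecewise_sigmoid(x: np.ndarray) -> np.ndarray:
    xc = np.clip(x, LO, HI)
    idx = np.minimum(((xc - LO) / (HI - LO) * BUCKETS).astype(int), BUCKETS - 1)
    x0, x1 = edges[idx], edges[idx + 1]
    y0, y1 = sigmoid(x0), sigmoid(x1)
    return y0 + (y1 - y0) * (xc - x0) / (x1 - x0)

xs = np.linspace(-10, 10, 10001)
print("max abs error:", np.abs(piecewise_sigmoid(xs) - sigmoid(xs)).max())
```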
Let’s briefly discuss how to do some basic operations with real numbers. For example, if we have two numbers [a] and [b], how do we compute [a] divided by [b] efficiently?
We can use iterative methods with a fixed number of iterations. For example, we can use the Newton–Raphson or Goldschmidt division algorithms: https://en.wikipedia.org/wiki/Division_algorithm
Note that it is essential to ensure no overflows are happening during the computations.
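Here is a plaintext sketch of Newton-Raphson division with a fixed number of iterations (the initial estimate and iteration count are illustrative choices, and we assume the divisor is positive):

```python
import math

def reciprocal(b: float, iterations: int = 5) -> float:
    """Approximate 1/b for b in [0.5, 1) with a fixed number of Newton-Raphson steps."""
    assert 0.5 <= b < 1.0
    y = 48 / 17 - (32 / 17) * b        # standard linear initial estimate on [0.5, 1)
    for _ in range(iterations):
        y = y * (2.0 - b * y)          # Newton-Raphson step: y_{k+1} = y_k * (2 - b * y_k)
    return y

def divide(a: float, b: float) -> float:
    """Compute a / b for b > 0 by scaling b into [0.5, 1) and multiplying by its reciprocal."""
    assert b > 0
    m, e = math.frexp(b)               # b = m * 2**e with 0.5 <= m < 1
    return a * reciprocal(m) / (2.0 ** e)   # dividing by 2**e is just a shift in fixed point

print(divide(10.0, 3.0))   # ~3.3333
```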
One challenging piece of the training algorithm is the Adam optimizer. It uses an inverse square root internally, which we compute with a Newton-Raphson iterative algorithm similar to the one used for division.
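The inverse square root follows the same pattern; a plaintext sketch (again with an illustrative initial guess and iteration count) looks like this:

```python
import math

def inv_sqrt(x: float, iterations: int = 5) -> float:
    """Approximate 1/sqrt(x) for x > 0 with a fixed number of Newton-Raphson steps."""
    assert x > 0
    m, e = math.frexp(x)               # x = m * 2**e with 0.5 <= m < 1
    if e % 2 != 0:                     # make the exponent even so it halves cleanly
        m, e = m / 2.0, e + 1
    # Now 1/sqrt(x) = 2**(-e/2) / sqrt(m), with m in [0.25, 1).
    y = 1.0                            # initial guess, good enough for m in [0.25, 1)
    for _ in range(iterations):
        y = y * (1.5 - 0.5 * m * y * y)    # Newton step for f(y) = 1/y**2 - m
    return y * 2.0 ** (-e // 2)

print(inv_sqrt(0.04))   # ~5.0
```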
One tricky aspect of optimizers is choosing the learning rate. The golden rule is to set it to 3*10^-4, which worked fine in our tests. As mentioned, we use fixed-precision numbers, with the smallest representable number equal to 2^-15 ≈ 3*10^-5. So if we needed a 10 times lower learning rate, it would sit at the boundary of representable precision, and we would need to change the fixed [exp].
Another issue is using the abovementioned approximations for various nonlinear operations. Unfortunately, if we replace some pieces in the model with some approximations, we can’t just use the existing pre-trained weights of a model. Instead, we must fine-tune (or retrain from scratch) the model before using it.
To generate each output token, we need to run inference on the graph, which takes all previous tokens as inputs. The default way of doing this is to construct a new graph with a different input size on each iteration. When using a model in SMPC, there is also a step of converting graphs to the SMPC format, which can take even more time than the model evaluation itself! In general this is not a problem, because you can convert a graph once and then run many computations with it.
For this specific case, however, we didn’t want to regenerate the graph for every new token, so we modified the initial graph to take a fixed input length together with a mask indicating which token positions are already filled in.
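A rough plaintext sketch of the idea (the model call, maximum length, and padding id are placeholders): the input buffer has a fixed length, and a mask marks which positions already hold real tokens, so the same compiled graph can be reused for every generation step.

```python
import numpy as np

MAX_LEN = 64          # fixed input length of the compiled graph (placeholder value)
PAD_ID = 0            # padding token id (placeholder)

def generate(model, prompt_ids: list[int], n_new_tokens: int) -> list[int]:
    tokens = np.full(MAX_LEN, PAD_ID, dtype=np.int64)
    mask = np.zeros(MAX_LEN, dtype=np.int64)     # 1 = position holds a real token
    tokens[: len(prompt_ids)] = prompt_ids
    mask[: len(prompt_ids)] = 1
    length = len(prompt_ids)

    for _ in range(n_new_tokens):
        # The same fixed-shape graph is evaluated every step; the mask tells the
        # model which positions to attend to and which are still padding.
        probs = model(tokens, mask)              # probabilities for the next token
        next_id = int(np.argmax(probs))          # (sampling is discussed below)
        tokens[length] = next_id
        mask[length] = 1
        length += 1
        if length == MAX_LEN:
            break
    return tokens[:length].tolist()
```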
Before generating each token of the text output, the model produces a long array of real numbers — the probability of each possible token appearing in the next position. We then have to choose one index of that array according to those probabilities. If the array is big (tens of thousands of elements), this step can be pretty slow in SMPC.
Instead of doing this step in SMPC, we return the array in its raw form to the user (but not to the model owner), and the user does the sampling in plaintext.
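A minimal sketch of that user-side sampling step (the probs array is whatever the protocol reveals to the user; the re-normalization is our own addition to guard against fixed-point rounding):

```python
import numpy as np

def sample_next_token(probs) -> int:
    """Sample a token index from the probability array returned by the SMPC inference."""
    p = np.asarray(probs, dtype=np.float64)
    p = np.clip(p, 0.0, None)
    p = p / p.sum()          # re-normalize: fixed-point outputs may not sum to exactly 1
    return int(np.random.choice(len(p), p=p))

# Hypothetical usage: `probs` is the raw array revealed only to the user.
# next_token = sample_next_token(probs)
```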
We trained a 30M-parameter, GPT-2-like model on the TinyStories dataset, with Softmax replaced by the SMPC-friendly version. We did the training in plaintext. After that, we converted the model to an SMPC graph and were able to run inference on it. You can see an example of this model generating a story from a prompt (the video is not real-time; it is sped up 4x):
We sympathize, Tim.
The model runs in decoder mode, which means it can only generate the next token once the previous token is known. Generating one token takes roughly 2 seconds. It is worth noting that the main bottleneck is network communication, not the actual CPU computations. If we used a 10x bigger model, computing one token would take roughly the same time, effectively making that scaling free from a computation point of view.
Generative AI has already transformed industries and business models, and is poised to do much more in the near future. However, for generative AI to be adopted at scale at the enterprise and government level, privacy and security must be central to the story. We believe that secure multi-party computation (SMPC) is a prime cryptographic technique for solving generative AI’s security problems.
We have already enabled generative AI models trained on TinyStories to be fine-tuned and used for inference under SMPC. Granted, the parameter counts of these models and of GPT-4/Llama 2/PaLM still differ by a few orders of magnitude, but we believe this gap can be closed in the future.
One promising direction is making better use of GPUs rather than CPUs.
Our SecureAI product is an optimal solution for securing a highly powerful generative AI model. To learn how SecureAI can secure your generative AI use case, talk to Pyte.
[1] To learn more about SMPC and the ABY3 protocol, see ABY3: A Mixed Protocol Framework for Machine Learning.
There is ongoing research on privacy-preserving transformers. If you are interested in this idea, you can read what others are doing in this field: