Private Tensor

Import dependencies


In [1]:
import syft as sy
import torch as th

Init Hook


In [2]:
hook = sy.TorchHook(th)

Create a worker reference

It works with Virtual Workers, PySyft Websocket Workers and also PyGrid Nodes.


In [3]:
node1 = sy.VirtualWorker(hook, id="Node1")

Sending a Private Tensor to worker


In [4]:
## Create a tensor
x = th.tensor([1,2,3,4,5,6]).tag("#cancer")

## Convert to Private Tensor
x_private = x.private_tensor(allowed_users=["Hospital1", "Hospital2"]) # It can be another data types, not only strings
x_private


Out[4]:
(Wrapper)>PrivateTensor>tensor([1, 2, 3, 4, 5, 6])
	Tags: #cancer 
	Shape: torch.Size([6])

In [5]:
x_private_pointer = x_private.send(node1)
x_private_pointer


Out[5]:
(Wrapper)>[PointerTensor | me:5393355131 -> Node1:62449623521]

Perform Get Method (Disallowed credentials)


In [6]:
x_private_pointer.get()


---------------------------------------------------------------------------
GetNotPermittedError                      Traceback (most recent call last)
<ipython-input-6-66ac7b4e6c35> in <module>
----> 1 x_private_pointer.get()

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/frameworks/torch/tensors/interpreters/native.py in get(self, inplace, user, reason, *args, **kwargs)
    537             tensor = self.child.get(*args, **kwargs)
    538         else:  # Remote tensor/chain
--> 539             tensor = self.child.get(*args, user=user, reason=reason, **kwargs)
    540 
    541         # Clean the wrapper

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/generic/pointers/pointer_tensor.py in get(self, user, reason, deregister_ptr)
    296             object used to point to #on a remote machine.
    297         """
--> 298         tensor = ObjectPointer.get(self, user=user, reason=reason, deregister_ptr=deregister_ptr)
    299 
    300         # TODO: remove these 3 lines

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/generic/pointers/object_pointer.py in get(self, user, reason, deregister_ptr)
    266         else:
    267             # get tensor from location
--> 268             obj = self.owner.request_obj(self.id_at_location, self.location, user, reason)
    269 
    270         # Remove this pointer by default

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/base.py in request_obj(self, obj_id, location, user, reason)
    606             A torch Tensor or Variable object.
    607         """
--> 608         obj = self.send_msg(ObjectRequestMessage((obj_id, user, reason)), location)
    609         return obj
    610 

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/base.py in send_msg(self, message, location)
    266 
    267         # Step 2: send the message and wait for a response
--> 268         bin_response = self._send_msg(bin_message, location)
    269 
    270         # Step 3: deserialize the response

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/virtual.py in _send_msg(self, message, location)
      5 class VirtualWorker(BaseWorker, FederatedClient):
      6     def _send_msg(self, message: bin, location: BaseWorker) -> bin:
----> 7         return location._recv_msg(message)
      8 
      9     def _recv_msg(self, message: bin) -> bin:

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/virtual.py in _recv_msg(self, message)
      8 
      9     def _recv_msg(self, message: bin) -> bin:
---> 10         return self.recv_msg(message)

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/base.py in recv_msg(self, bin_message)
    300             print(f"worker {self} received {sy.codes.code2MSGTYPE[msg_type]} {contents}")
    301         # Step 1: route message to appropriate function
--> 302         response = self._message_router[msg_type](contents)
    303 
    304         # Step 2: Serialize the message to simple python objects

~/workspace/v/lib/python3.7/site-packages/syft-0.2.0a2-py3.7.egg/syft/workers/base.py in respond_to_obj_req(self, request_obj)
    549         obj = self.get_obj(obj_id)
    550         if hasattr(obj, "allowed_to_get") and not obj.allowed_to_get(user):
--> 551             raise GetNotPermittedError()
    552         else:
    553             self.de_register_obj(obj)

GetNotPermittedError: 

Perform Get Method (Allowed Credentials)


In [13]:
x_private_pointer.get(user="Hospital1", reason="I'm the owner!")


Out[13]:
(Wrapper)>PrivateTensor>tensor([1, 2, 3, 4, 5, 6])
	Tags: #cancer 
	Shape: torch.Size([6])