|
| 1 | +/* |
| 2 | +Copyright 2023 Hedgehog SONiC Foundation |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +*/ |
| 16 | + |
| 17 | +package tpm |
| 18 | + |
| 19 | +import ( |
| 20 | + "context" |
| 21 | + "errors" |
| 22 | + "fmt" |
| 23 | + "net" |
| 24 | + "os" |
| 25 | + "path/filepath" |
| 26 | + "time" |
| 27 | + |
| 28 | + "go.uber.org/zap" |
| 29 | + "google.golang.org/grpc" |
| 30 | + "google.golang.org/grpc/credentials/insecure" |
| 31 | + |
| 32 | + "go.githedgehog.com/k8s-tpm-device-plugin/internal/plugin" |
| 33 | + |
| 34 | + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" |
| 35 | +) |
| 36 | + |
| 37 | +const ( |
| 38 | + tpmID = "tpm0" |
| 39 | + tpmSocketName = "hh-tpm.sock" |
| 40 | +) |
| 41 | + |
| 42 | +var ( |
| 43 | + connectionTimeout = time.Second * 5 |
| 44 | + registerTimeout = time.Second * 30 |
| 45 | + errUnimplmented = errors.New("plugin does not implement this method") |
| 46 | +) |
| 47 | + |
| 48 | +func UnimplementedError(str string) error { |
| 49 | + return fmt.Errorf("%w: %s", errUnimplmented, str) |
| 50 | +} |
| 51 | + |
| 52 | +type tpmDevicePlugin struct { |
| 53 | + l *zap.Logger |
| 54 | + tctiEnvVar bool |
| 55 | + socketPath string |
| 56 | + server *grpc.Server |
| 57 | + stopCh chan struct{} |
| 58 | +} |
| 59 | + |
| 60 | +var _ plugin.Interface = &tpmDevicePlugin{} |
| 61 | +var _ pluginapi.DevicePluginServer = &tpmDevicePlugin{} |
| 62 | + |
| 63 | +func New(l *zap.Logger, tctiEnvVar bool) (plugin.Interface, error) { |
| 64 | + return &tpmDevicePlugin{ |
| 65 | + l: l.With(zap.String("plugin", "tpm")), |
| 66 | + tctiEnvVar: tctiEnvVar, |
| 67 | + socketPath: filepath.Join(pluginapi.DevicePluginPath, tpmSocketName), |
| 68 | + // will be initialized by Start() |
| 69 | + server: nil, |
| 70 | + stopCh: nil, |
| 71 | + }, nil |
| 72 | +} |
| 73 | + |
| 74 | +func (p *tpmDevicePlugin) init() { |
| 75 | + p.server = grpc.NewServer() |
| 76 | + p.stopCh = make(chan struct{}) |
| 77 | +} |
| 78 | + |
| 79 | +func (p *tpmDevicePlugin) cleanup() { |
| 80 | + close(p.stopCh) |
| 81 | + p.server = nil |
| 82 | + p.stopCh = nil |
| 83 | +} |
| 84 | + |
| 85 | +func (p *tpmDevicePlugin) Name() string { |
| 86 | + return "tpm" |
| 87 | +} |
| 88 | + |
| 89 | +// Start implements Interface |
| 90 | +func (p *tpmDevicePlugin) Start(ctx context.Context) error { |
| 91 | + // caller safeguard |
| 92 | + if p == nil { |
| 93 | + return nil |
| 94 | + } |
| 95 | + p.init() |
| 96 | + |
| 97 | + if err := p.Serve(ctx); err != nil { |
| 98 | + return err |
| 99 | + } |
| 100 | + p.l.Info("TPM Device Plugin server started") |
| 101 | + if err := p.Register(ctx); err != nil { |
| 102 | + return err |
| 103 | + } |
| 104 | + p.l.Info("TPM Device Plugin registered with kubelet") |
| 105 | + |
| 106 | + return nil |
| 107 | +} |
| 108 | + |
| 109 | +// Stop implements Interface |
| 110 | +func (p *tpmDevicePlugin) Stop(context.Context) error { |
| 111 | + // caller safeguard |
| 112 | + if p == nil || p.server == nil { |
| 113 | + return nil |
| 114 | + } |
| 115 | + p.l.Info("Stopping gRPC server", zap.String("socket", p.socketPath)) |
| 116 | + p.server.Stop() |
| 117 | + if err := os.Remove(p.socketPath); err != nil && !os.IsNotExist(err) { |
| 118 | + return fmt.Errorf("removing socket path %s: %w", p.socketPath, err) |
| 119 | + } |
| 120 | + p.cleanup() |
| 121 | + return nil |
| 122 | +} |
| 123 | + |
| 124 | +func (p *tpmDevicePlugin) Serve(ctx context.Context) error { |
| 125 | + // listen on unix socket |
| 126 | + // NOTE: no need to close the listener as the gRPC methods close the listener automatically |
| 127 | + if err := os.Remove(p.socketPath); err != nil && !os.IsNotExist(err) { |
| 128 | + return fmt.Errorf("removing socket path %s: %w", p.socketPath, err) |
| 129 | + } |
| 130 | + var lc net.ListenConfig |
| 131 | + l, err := lc.Listen(ctx, "unix", p.socketPath) |
| 132 | + if err != nil { |
| 133 | + return fmt.Errorf("listening on unix socket %s: %w", p.socketPath, err) |
| 134 | + } |
| 135 | + p.l.Info("Listening on unix socket for gRPC server now", zap.String("socket", p.socketPath)) |
| 136 | + |
| 137 | + // register the device plugin server API with the grpc server |
| 138 | + pluginapi.RegisterDevicePluginServer(p.server, p) |
| 139 | + |
| 140 | + // now run the gRPC server |
| 141 | + go func() { |
| 142 | + for { |
| 143 | + p.l.Info("Starting gRPC server now...") |
| 144 | + err := p.server.Serve(l) |
| 145 | + // err is nil when Stop() or GracefulStop() were called |
| 146 | + if err == nil { |
| 147 | + p.l.Info("Stopped gRPC server") |
| 148 | + return |
| 149 | + } |
| 150 | + p.l.Error("gRPC server crashed", zap.Error(err)) |
| 151 | + } |
| 152 | + }() |
| 153 | + |
| 154 | + // connect to the gRPC server in blocking mode to ensure it is up before we return here |
| 155 | + subCtx, cancel := context.WithTimeout(ctx, connectionTimeout) |
| 156 | + defer cancel() |
| 157 | + conn, err := grpc.DialContext(subCtx, "unix:"+p.socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) |
| 158 | + if err != nil { |
| 159 | + return fmt.Errorf("gRPC server did not start within timeout %v: %w", connectionTimeout, err) |
| 160 | + } |
| 161 | + conn.Close() // nolint: errcheck |
| 162 | + |
| 163 | + p.l.Info("Started gRPC server") |
| 164 | + return nil |
| 165 | +} |
| 166 | + |
| 167 | +func (p *tpmDevicePlugin) Register(ctx context.Context) error { |
| 168 | + // connect to kubelet socket |
| 169 | + connCtx, connCancel := context.WithTimeout(ctx, connectionTimeout) |
| 170 | + defer connCancel() |
| 171 | + conn, err := grpc.DialContext(connCtx, "unix:"+pluginapi.KubeletSocket, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) |
| 172 | + if err != nil { |
| 173 | + return fmt.Errorf("connecting to kubelet socket at %s: %w", pluginapi.KubeletSocket, err) |
| 174 | + } |
| 175 | + |
| 176 | + client := pluginapi.NewRegistrationClient(conn) |
| 177 | + |
| 178 | + regCtx, regCancel := context.WithTimeout(ctx, registerTimeout) |
| 179 | + defer regCancel() |
| 180 | + if _, err := client.Register(regCtx, &pluginapi.RegisterRequest{ |
| 181 | + Version: pluginapi.Version, |
| 182 | + Endpoint: tpmSocketName, |
| 183 | + ResourceName: "githedgehog.com/tpm", |
| 184 | + Options: &pluginapi.DevicePluginOptions{ |
| 185 | + PreStartRequired: false, |
| 186 | + GetPreferredAllocationAvailable: false, |
| 187 | + }, |
| 188 | + }); err != nil { |
| 189 | + return fmt.Errorf("gRPC register call: %w", err) |
| 190 | + } |
| 191 | + |
| 192 | + return nil |
| 193 | +} |
| 194 | + |
| 195 | +// Allocate implements v1beta1.DevicePluginServer |
| 196 | +func (p *tpmDevicePlugin) Allocate(_ context.Context, allocateRequest *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { |
| 197 | + p.l.Debug("Allocate() call", zap.Reflect("allocateRequest", allocateRequest)) |
| 198 | + resp := &pluginapi.AllocateResponse{} |
| 199 | + for _, req := range allocateRequest.ContainerRequests { |
| 200 | + p.l.Debug("allocate ContainerRequest", zap.Reflect("creq", req)) |
| 201 | + var envs map[string]string |
| 202 | + if p.tctiEnvVar { |
| 203 | + envs = map[string]string{ |
| 204 | + "TPM2TOOLS_TCTI": "device:/dev/tpm0", |
| 205 | + } |
| 206 | + } |
| 207 | + cresp := &pluginapi.ContainerAllocateResponse{ |
| 208 | + Envs: envs, |
| 209 | + Devices: []*pluginapi.DeviceSpec{ |
| 210 | + { |
| 211 | + ContainerPath: "/dev/tpm0", |
| 212 | + HostPath: "/dev/tpm0", |
| 213 | + Permissions: "rwm", |
| 214 | + }, |
| 215 | + }, |
| 216 | + } |
| 217 | + resp.ContainerResponses = append(resp.ContainerResponses, cresp) |
| 218 | + } |
| 219 | + return resp, nil |
| 220 | +} |
| 221 | + |
| 222 | +// GetDevicePluginOptions implements v1beta1.DevicePluginServer |
| 223 | +func (*tpmDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { |
| 224 | + return &pluginapi.DevicePluginOptions{ |
| 225 | + PreStartRequired: false, |
| 226 | + GetPreferredAllocationAvailable: false, |
| 227 | + }, nil |
| 228 | +} |
| 229 | + |
| 230 | +// GetPreferredAllocation implements v1beta1.DevicePluginServer |
| 231 | +func (p *tpmDevicePlugin) GetPreferredAllocation(_ context.Context, _ *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { |
| 232 | + p.l.Debug("GetPreferredAllocation() is unimplemented for this plugin") |
| 233 | + return nil, UnimplementedError("GetPreferredAllocation") |
| 234 | +} |
| 235 | + |
| 236 | +// ListAndWatch implements v1beta1.DevicePluginServer |
| 237 | +func (p *tpmDevicePlugin) ListAndWatch(_ *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error { |
| 238 | + s.Send(&pluginapi.ListAndWatchResponse{Devices: []*pluginapi.Device{ |
| 239 | + { |
| 240 | + ID: tpmID, |
| 241 | + Health: pluginapi.Healthy, |
| 242 | + }, |
| 243 | + }}) |
| 244 | + |
| 245 | + // TODO: there is nothing we are doing at the moment to check if the TPM is healthy or not |
| 246 | + <-p.stopCh |
| 247 | + |
| 248 | + return nil |
| 249 | +} |
| 250 | + |
| 251 | +// PreStartContainer implements v1beta1.DevicePluginServer |
| 252 | +func (p *tpmDevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { |
| 253 | + p.l.Debug("PreStartContainer() is unimplemented for this plugin") |
| 254 | + return &pluginapi.PreStartContainerResponse{}, nil |
| 255 | +} |
0 commit comments