@@ -10,8 +10,7 @@ abstract type AddressSpace end
1010
1111module AS
1212
13- using CUDAnative
14- import CUDAnative: AddressSpace
13+ import .. AddressSpace
1514
1615struct Generic <: AddressSpace end
1716struct Global <: AddressSpace end
2625# Device pointer
2726#
2827
29- struct DevicePtr{T,A}
30- ptr :: Ptr{T }
28+ """
29+ DevicePtr{T,A }
3130
32- # inner constructors, fully parameterized
33- DevicePtr {T,A} (ptr :: Ptr{T} ) where {T,A <: AddressSpace } = new (ptr)
34- end
35-
36- # outer constructors, partially parameterized
37- DevicePtr {T} (ptr :: Ptr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
31+ A memory address that refers to data of type `T` that is accessible from the GPU. It is the
32+ on-device counterpart of `CUDAdrv.CuPtr`, additionally keeping track of the address space
33+ `A` where the data resides (shared, global, constant, etc). This information is used to
34+ provide optimized implementations of operations such as `unsafe_load` and `unsafe_store!.`
35+ """
36+ DevicePtr
3837
39- # outer constructors, non-parameterized
40- DevicePtr (ptr:: Ptr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
38+ if sizeof (Ptr{Cvoid}) == 8
39+ primitive type DevicePtr{T,A} 64 end
40+ else
41+ primitive type DevicePtr{T,A} 32 end
42+ end
4143
42- Base. show (io:: IO , dp:: DevicePtr{T,AS} ) where {T,AS} =
43- print (io, AS. name. name, " Device" , pointer (dp))
44+ # constructors
45+ DevicePtr {T,A} (x:: Union{Int,UInt,CuPtr,DevicePtr} ) where {T,A<: AddressSpace } = Base. bitcast (DevicePtr{T,A}, x)
46+ DevicePtr {T} (ptr:: CuPtr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
47+ DevicePtr (ptr:: CuPtr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
4448
4549
4650# # getters
4751
48- Base. pointer (p:: DevicePtr ) = p. ptr
49-
5052Base. eltype (:: Type{<:DevicePtr{T}} ) where {T} = T
5153
5254addrspace (x:: DevicePtr ) = addrspace (typeof (x))
@@ -55,20 +57,23 @@ addrspace(::Type{DevicePtr{T,A}}) where {T,A} = A
5557
5658# # conversions
5759
58- # between regular and device pointers
59- # # simple conversions disallowed
60- Base. convert (:: Type{Ptr{T}} , p:: DevicePtr{T} ) where {T} = throw (InexactError (:convert , Ptr{T}, p))
61- Base. convert (:: Type{<:DevicePtr{T}} , p:: Ptr{T} ) where {T} = throw (InexactError (:convert , DevicePtr{T}, p))
62- # # unsafe ones are allowed
63- Base. unsafe_convert (:: Type{Ptr{T}} , p:: DevicePtr{T} ) where {T} = pointer (p)
60+ # to and from integers
61+ # # pointer to integer
62+ Base. convert (:: Type{T} , x:: DevicePtr ) where {T<: Integer } = T (UInt (x))
63+ # # integer to pointer
64+ Base. convert (:: Type{DevicePtr{T,A}} , x:: Union{Int,UInt} ) where {T,A<: AddressSpace } = DevicePtr {T,A} (x)
65+ Int (x:: DevicePtr ) = Base. bitcast (Int, x)
66+ UInt (x:: DevicePtr ) = Base. bitcast (UInt, x)
6467
65- # defer conversions to DevicePtr to unsafe_convert
66- Base. cconvert (:: Type{<:DevicePtr} , x) = x
68+ # between host and device pointers
69+ Base. convert (:: Type{CuPtr{T}} , p:: DevicePtr ) where {T} = Base. bitcast (CuPtr{T}, p)
70+ Base. convert (:: Type{DevicePtr{T,A}} , p:: CuPtr ) where {T,A<: AddressSpace } = Base. bitcast (DevicePtr{T,A}, p)
71+ Base. convert (:: Type{DevicePtr{T}} , p:: CuPtr ) where {T} = Base. bitcast (DevicePtr{T,AS. Generic}, p)
6772
6873# between device pointers
69- Base. convert (:: Type{<:DevicePtr} , p:: DevicePtr ) = throw (InexactError ( : convert, DevicePtr, p ))
74+ Base. convert (:: Type{<:DevicePtr} , p:: DevicePtr ) = throw (ArgumentError ( " cannot convert between incompatible device pointer types " ))
7075Base. convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr{T,A} ) where {T,A} = p
71- Base. unsafe_convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr ) where {T,A} = DevicePtr {T,A} ( reinterpret (Ptr{T}, pointer (p)) )
76+ Base. unsafe_convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr ) where {T,A} = Base . bitcast ( DevicePtr{T,A}, p )
7277# # identical addrspaces
7378Base. convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr{U,A} ) where {T,U,A} = Base. unsafe_convert (DevicePtr{T,A}, p)
7479# # convert to & from generic
@@ -78,19 +83,25 @@ Base.convert(::Type{DevicePtr{T,AS.Generic}}, p::DevicePtr{T,AS.Generic}) where
7883# # unspecified, preserve source addrspace
7984Base. convert (:: Type{DevicePtr{T}} , p:: DevicePtr{U,A} ) where {T,U,A} = Base. unsafe_convert (DevicePtr{T,A}, p)
8085
86+ # defer conversions to DevicePtr to unsafe_convert
87+ Base. cconvert (:: Type{<:DevicePtr} , x) = x
88+
8189
8290# # limited pointer arithmetic & comparison
8391
84- Base.:(== )(a:: DevicePtr , b:: DevicePtr ) = pointer (a) == pointer (b) && addrspace (a) == addrspace (b)
92+ isequal (x:: DevicePtr , y:: DevicePtr ) = (x === y) && addrspace (x) == addrspace (y)
93+ isless (x:: DevicePtr{T,A} , y:: DevicePtr{T,A} ) where {T,A<: AddressSpace } = x < y
8594
86- Base. isless (x:: DevicePtr , y:: DevicePtr ) = Base. isless (pointer (x), pointer (y))
87- Base.:(- )(x:: DevicePtr , y:: DevicePtr ) = pointer (x) - pointer (y)
95+ Base.:(== )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) == UInt (y) && addrspace (x) == addrspace (y)
96+ Base.:(< )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) < UInt (y)
97+ Base.:(- )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) - UInt (y)
8898
89- Base.:(+ )(x:: DevicePtr{T,A} , y:: Integer ) where {T,A} = DevicePtr {T,A} ( pointer (x) + y )
90- Base.:(- )(x:: DevicePtr{T,A} , y:: Integer ) where {T,A} = DevicePtr {T,A} ( pointer (x) - y )
99+ Base.:(+ )(x:: DevicePtr , y:: Integer ) = oftype (x, Base . add_ptr ( UInt (x), (y % UInt) % UInt) )
100+ Base.:(- )(x:: DevicePtr , y:: Integer ) = oftype (x, Base . sub_ptr ( UInt (x), (y % UInt) % UInt) )
91101Base.:(+ )(x:: Integer , y:: DevicePtr ) = y + x
92102
93103
104+
94105# # memory operations
95106
96107Base. convert (:: Type{Int} , :: Type{AS.Generic} ) = 0
@@ -121,7 +132,7 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
121132 eltyp = convert (LLVMType, T)
122133
123134 T_int = convert (LLVMType, Int)
124- T_ptr = convert (LLVMType, Ptr{T })
135+ T_ptr = convert (LLVMType, DevicePtr{T,A })
125136
126137 T_actual_ptr = LLVM. PointerType (eltyp)
127138
@@ -148,15 +159,15 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
148159 ret! (builder, ld)
149160 end
150161
151- call_function (llvm_f, T, Tuple{Ptr{T }, Int}, :((pointer (p) , Int (i- one (i)))))
162+ call_function (llvm_f, T, Tuple{DevicePtr{T,A }, Int}, :((p , Int (i- one (i)))))
152163end
153164
154165@generated function Base. unsafe_store! (p:: DevicePtr{T,A} , x, i:: Integer = 1 ,
155166 :: Val{align} = Val (1 )) where {T,A,align}
156167 eltyp = convert (LLVMType, T)
157168
158169 T_int = convert (LLVMType, Int)
159- T_ptr = convert (LLVMType, Ptr{T })
170+ T_ptr = convert (LLVMType, DevicePtr{T,A })
160171
161172 T_actual_ptr = LLVM. PointerType (eltyp)
162173
184195 ret! (builder)
185196 end
186197
187- call_function (llvm_f, Cvoid, Tuple{Ptr{T}, T, Int}, :((pointer (p), convert (T,x), Int (i- one (i)))))
198+ call_function (llvm_f, Cvoid, Tuple{DevicePtr{T,A}, T, Int},
199+ :((p, convert (T,x), Int (i- one (i)))))
188200end
189201
190202# # loading through the texture cache
@@ -215,7 +227,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
215227
216228 T_int = convert (LLVMType, Int)
217229 T_int32 = LLVM. Int32Type (JuliaContext ())
218- T_ptr = convert (LLVMType, Ptr{T })
230+ T_ptr = convert (LLVMType, DevicePtr{T,AS . Global })
219231
220232 T_actual_ptr = LLVM. PointerType (eltyp)
221233 T_actual_ptr_as = LLVM. PointerType (eltyp, convert (Int, AS. Global))
@@ -258,7 +270,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
258270 ret! (builder, ld)
259271 end
260272
261- call_function (llvm_f, T, Tuple{Ptr{T }, Int}, :((pointer (p) , Int (i- one (i)))))
273+ call_function (llvm_f, T, Tuple{DevicePtr{T,AS . Global }, Int}, :((p , Int (i- one (i)))))
262274end
263275
264276@inline unsafe_cached_load (p:: DevicePtr{T,AS.Global} , i:: Integer = 1 , args... ) where {T} =
0 commit comments