ThreadLocal是怎么实现线程隔离的

ThreadLocal大家应该都不陌生,见过最多的使用场景应该是和SimpleDateFormat一起使用吧,因为这个SDF非线程安全的,所以需要使用ThreadLocal将它在线程之间隔离开,避免造成脏数据的🐞。那么ThreadLocal是怎么保证线程安全,又是如何操作的呢?

案例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public static void main(String[] args) {

ThreadLocal<Integer> threadLocal = new ThreadLocal<>();

new Thread(new Runnable() {
@Override
public void run() {
threadLocal.set(1);
threadLocal.set(2);
System.out.println("cc1: " + threadLocal.get());
}
}, "cc1").start();

new Thread(new Runnable() {
@Override
public void run() {
System.out.println("cc2: " + threadLocal.get());
}
}, "cc2").start();

}

输出:

1
2
cc1: 2
cc2: null

哦哟~cc2打印出来null,也就是在cc1线程中设置的值在线程cc2中获取不到,这也就是所谓的线程隔离,我们来看下ThreadLocal具体的代码实现吧:

ThreadLocal的set(T t)方法源码

1
2
3
4
5
6
7
8
9
10
11
12
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取当前线程的threadLocals属性,这个属性在Thread类中定义的,为Thread的实例变量
ThreadLocalMap map = getMap(t);
// 若线程的ThreadLocalMap已经存在,则调用ThreadLocalMap的set(ThreadLocal<T> key, Object value)方法
// 否则创建新的ThreadLocalMap实例,并set对应的value
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

ThreadLocalMap的set(ThreadLocal key, Object value)方法源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
private void set(ThreadLocal<?> key, Object value) {

Entry[] tab = table;
int len = tab.length;
// 简单计算key所在的位置
int i = key.threadLocalHashCode & (len-1);
// 从key所在位置开始遍历table数组,找到具体key所在的位置
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
// 获取Entry实例的key值,这里调用的是超类java.lang.ref.Reference中的get(T t)方法
ThreadLocal<?> k = e.get();
// 若k与传入的参数key是同一个,则用参数value替换Entry实例的value,然后结束方法
if (k == key) {
e.value = value;
return;
}
// 若获取的k为null,则表示这个变量已经被删除了,则去清理一下table数组,并对数组中元素进行清理并设置新的Entry实例
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 代码走到这一步,说明该线程第一次设置数据,创建新的Entry实例放在table的第i个位置上
tab[i] = new Entry(key, value);
int sz = ++size;
// 清理table中的元素,若长度达到了扩容阈值,则对table进行扩容,扩容为原数组长度的2倍
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

ThreadLocal的createMap(Thread t, T firstValue)方法源码

1
2
3
4
5
void createMap(Thread t, T firstValue) {
// 创建一个ThreadLocalMap实例,并赋值给当前线程的实例变量threadLocals
// 这里就是线程隔离的关键所在,每一个线程中的数据都是由线程独有的threadLocals变量存储的
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap的构造器源码

1
2
3
4
5
6
7
8
9
10
11
12
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 实例化Entry数组,长度为初始长度16
table = new Entry[INITIAL_CAPACITY];
// 计算key在数组中的位置
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 创建Entry实例,并放在table的i下标位置
table[i] = new Entry(firstKey, firstValue);
// 实际长度设置为1
size = 1;
// 设置数组扩容阈值(len * 2 / 3)
setThreshold(INITIAL_CAPACITY);
}

以上便是ThreadLocal达到线程隔离的基本解析,讲解的比较基础,其实就是JDK源码鉴赏,还有什么不懂的地方就自己去看源码吧。

延伸下

ThreadLocal的get()方法源码

1
2
3
4
5
6
7
8
9
10
11
12
13
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

这段代码比较简单,这里就不在进行解释了,我们着重看一下最后一句setInitialValue()这个方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

protected T initialValue() {
return null;
}

会发现和set方法类似,只不过是将一个null当做value而已,所以我们在没给ThreadLocal设置值的情况下调用get方法,则会为其创建一个默认的null值并返回null。

留一个思考题

因为我们每个线程的ThreadLocal的key的hash值都是固定的,那么Thread的threadLocals变量的table中会有多少个非null元素呢?